add QuiP quant support #217

Open: wants to merge 6 commits into master

Changes from all commits

exllamav2/attn.py (137 changes: 89 additions & 48 deletions)
@@ -9,6 +9,7 @@
import math
from exllamav2 import ext
from exllamav2.ext import exllamav2_ext as ext_c
from exllamav2.quip_linear import QuipLinear
# import xformers.ops as xops
# from exllamav2.util import list_live_tensors, set_snapshot, diff_snapshot, print_vram_usage_peak

@@ -34,6 +35,11 @@ class ExLlamaV2Attention(ExLlamaV2Module):
k_proj: ExLlamaV2Linear or None
v_proj: ExLlamaV2Linear or None
o_proj: ExLlamaV2Linear
qkv_proj: ExLlamaV2Linear or None
k_scale: torch.Tensor or None
o_scale: torch.Tensor or None
q_scale: torch.Tensor or None
v_scale: torch.Tensor or None

name: str = "Attention"
submodules: list
@@ -59,29 +65,45 @@ def __init__(self, model, key, layer_idx):
hidden_size = self.model.config.hidden_size

self.input_layernorm = ExLlamaV2RMSNorm(model, key + ".input_layernorm")
self.q_proj = ExLlamaV2Linear(model, key + ".self_attn.q_proj", hidden_size, self.model.config.num_attention_heads * self.model.config.head_dim, False)
self.k_proj = ExLlamaV2Linear(model, key + ".self_attn.k_proj", hidden_size, self.model.config.num_key_value_heads * self.model.config.head_dim, False)
self.v_proj = ExLlamaV2Linear(model, key + ".self_attn.v_proj", hidden_size, self.model.config.num_key_value_heads * self.model.config.head_dim, False)
self.o_proj = ExLlamaV2Linear(model, key + ".self_attn.o_proj", self.model.config.num_attention_heads * self.model.config.head_dim, hidden_size, False)

self.submodules = [self.input_layernorm,
self.q_proj,
self.k_proj,
self.v_proj,
self.o_proj]
self.submodules = [self.input_layernorm]
if model.config.is_quip:
self.qkv_proj = QuipLinear(model,
key + ".self_attn.qkv_proj",
hidden_size,
(self.model.config.num_attention_heads * self.model.config.head_dim) +
(self.model.config.num_key_value_heads * self.model.config.head_dim) +
(self.model.config.num_key_value_heads * self.model.config.head_dim))

self.o_proj = QuipLinear(model, key + ".self_attn.o_proj",
self.model.config.num_attention_heads * self.model.config.head_dim,
hidden_size)
self.submodules += [self.qkv_proj, self.o_proj]
else:
self.q_proj = ExLlamaV2Linear(model, key + ".self_attn.q_proj", hidden_size, self.model.config.num_attention_heads * self.model.config.head_dim, False)
self.k_proj = ExLlamaV2Linear(model, key + ".self_attn.k_proj", hidden_size, self.model.config.num_key_value_heads * self.model.config.head_dim, False)
self.v_proj = ExLlamaV2Linear(model, key + ".self_attn.v_proj", hidden_size, self.model.config.num_key_value_heads * self.model.config.head_dim, False)
self.o_proj = ExLlamaV2Linear(model, key + ".self_attn.o_proj", self.model.config.num_attention_heads * self.model.config.head_dim, hidden_size, False)
self.submodules += [self.q_proj, self.k_proj, self.v_proj, self.o_proj]


def load(self):
if self.model.config.is_quip:
w = self.load_weight()
self.k_scale = w['k_scale']
self.o_scale = w['o_scale']
self.q_scale = w['q_scale']
self.v_scale = w['v_scale']

qkv_embed = self.model.config.qkv_embed and self.layer_idx == 0

self.input_layernorm.load()
self.q_proj.load()
self.k_proj.load()
self.v_proj.load()
if hasattr(self, 'input_layernorm') and self.input_layernorm is not None: self.input_layernorm.load()
if hasattr(self, 'q_proj') and self.q_proj is not None: self.q_proj.load()
if hasattr(self, 'k_proj') and self.k_proj is not None: self.k_proj.load()
if hasattr(self, 'v_proj') and self.v_proj is not None: self.v_proj.load()
if hasattr(self, 'qkv_proj') and self.qkv_proj is not None: self.qkv_proj.load()
self.o_proj.load()

if self.q_proj.is_quant():
if hasattr(self, 'q_proj') and self.q_proj is not None and self.q_proj.is_quant():

assert self.k_proj.is_quant() and self.v_proj.is_quant() and self.o_proj.is_quant(), "Partially quantized attention layer"

@@ -116,15 +138,15 @@ def load(self):

embedding = self.model.modules[0]
assert isinstance(embedding, ExLlamaV2Embedding)
q = self.q_proj.get_weight_tensor_dq()
k = self.k_proj.get_weight_tensor_dq()
v = self.v_proj.get_weight_tensor_dq()
q = self.q_proj.get_weight_tensor_dq() if hasattr(self, 'q_proj') and self.q_proj is not None else None
k = self.k_proj.get_weight_tensor_dq() if hasattr(self, 'k_proj') and self.k_proj is not None else None
v = self.v_proj.get_weight_tensor_dq() if hasattr(self, 'v_proj') and self.v_proj is not None else None
norm = self.input_layernorm
embedding.make_qkv(norm, q, k, v)

self.q_proj.unload(); self.q_proj = None
self.k_proj.unload(); self.k_proj = None
self.v_proj.unload(); self.v_proj = None
if hasattr(self, 'q_proj') and self.q_proj is not None: self.q_proj.unload(); self.q_proj = None
if hasattr(self, 'k_proj') and self.k_proj is not None: self.k_proj.unload(); self.k_proj = None
if hasattr(self, 'v_proj') and self.v_proj is not None: self.v_proj.unload(); self.v_proj = None
self.input_layernorm.unload(); self.input_layernorm = None


@@ -133,10 +155,11 @@ def unload(self):
ext_c.free_q_attn(self.q_handle)
self.q_handle = None

if self.input_layernorm is not None: self.input_layernorm.unload()
if self.q_proj is not None: self.q_proj.unload()
if self.k_proj is not None: self.k_proj.unload()
if self.v_proj is not None: self.v_proj.unload()
if hasattr(self, 'qkv_proj') and self.qkv_proj is not None: self.qkv_proj.unload()
if hasattr(self, 'input_layernorm') and self.input_layernorm is not None: self.input_layernorm.unload()
if hasattr(self, 'q_proj') and self.q_proj is not None: self.q_proj.unload()
if hasattr(self, 'k_proj') and self.k_proj is not None: self.k_proj.unload()
if hasattr(self, 'v_proj') and self.v_proj is not None: self.v_proj.unload()
self.o_proj.unload()


@@ -149,9 +172,10 @@ def weight_footprint(self, qkv_embed = False):
else:

return self.input_layernorm.weight_footprint() + \
self.q_proj.weight_footprint() + \
self.k_proj.weight_footprint() + \
self.v_proj.weight_footprint() + \
(self.q_proj.weight_footprint() if hasattr(self, 'q_proj') and self.q_proj is not None else 0) + \
(self.k_proj.weight_footprint() if hasattr(self, 'k_proj') and self.k_proj is not None else 0) + \
(self.v_proj.weight_footprint() if hasattr(self, 'v_proj') and self.v_proj is not None else 0) + \
(self.qkv_proj.weight_footprint() if hasattr(self, 'qkv_proj') and self.qkv_proj is not None else 0) + \
self.o_proj.weight_footprint()


@@ -193,10 +217,10 @@ def temp_v_size(self):


def temp_dq_size(self):

return max(self.q_proj.temp_dq_size(),
self.k_proj.temp_dq_size(),
self.v_proj.temp_dq_size(),
return max(self.q_proj.temp_dq_size() if hasattr(self, 'q_proj') and self.q_proj is not None else 0,
self.k_proj.temp_dq_size() if hasattr(self, 'k_proj') and self.k_proj is not None else 0,
self.v_proj.temp_dq_size() if hasattr(self, 'v_proj') and self.v_proj is not None else 0,
self.qkv_proj.temp_dq_size() if hasattr(self, 'qkv_proj') and self.qkv_proj is not None else 0,
self.o_proj.temp_dq_size())


@@ -222,9 +246,10 @@ def set_device_idx(self, idx):
super().set_device_idx(idx)

self.input_layernorm.set_device_idx(idx)
self.q_proj.set_device_idx(idx)
self.k_proj.set_device_idx(idx)
self.v_proj.set_device_idx(idx)
if hasattr(self, 'q_proj') and self.q_proj is not None: self.q_proj.set_device_idx(idx)
if hasattr(self, 'k_proj') and self.k_proj is not None: self.k_proj.set_device_idx(idx)
if hasattr(self, 'v_proj') and self.v_proj is not None: self.v_proj.set_device_idx(idx)
if hasattr(self, 'qkv_proj') and self.qkv_proj is not None: self.qkv_proj.set_device_idx(idx)
self.o_proj.set_device_idx(idx)


@@ -517,21 +542,37 @@ def forward_torch(self, hidden_states, cache = None, attn_mask = None, past_len
residual = hidden_states
post_norm = self.input_layernorm.forward(hidden_states)

query_states_im = self.q_proj.forward(post_norm, loras = loras)
key_states_im = self.k_proj.forward(post_norm, loras = loras)
value_states_im = self.v_proj.forward(post_norm, loras = loras)
if self.model.config.is_quip:
qkv_states = self.qkv_proj.forward(post_norm.to(torch.float32), loras = loras)
query_states = self.q_scale * qkv_states[..., 0 : num_attention_heads * head_dim]
key_states = self.k_scale * qkv_states[..., num_attention_heads * head_dim : num_attention_heads * head_dim + num_key_value_heads * head_dim]
value_states = self.v_scale * qkv_states[..., num_attention_heads * head_dim + num_key_value_heads * head_dim : num_attention_heads * head_dim + 2 * num_key_value_heads * head_dim]
else:
query_states_im = self.q_proj.forward(post_norm, loras = loras)
key_states_im = self.k_proj.forward(post_norm, loras = loras)
value_states_im = self.v_proj.forward(post_norm, loras = loras)

if intermediates:

    query_states = query_states_im.clone()
    key_states = key_states_im.clone()
    value_states = value_states_im.clone()

else:

    query_states = query_states_im
    key_states = key_states_im
    value_states = value_states_im

# Alternative, for embedded QKV

@@ -606,7 +647,7 @@ def forward_torch(self, hidden_states, cache = None, attn_mask = None, past_len

# Output projection

attn_proj = self.o_proj.forward(attn_output, loras = loras)
attn_proj = self.o_scale * self.o_proj.forward(attn_output, loras = loras) if self.model.config.is_quip else self.o_proj.forward(attn_output, loras = loras)

# Add residual connection

@@ -651,5 +692,5 @@ def update_loras(self):


def is_quant(self):
return self.q_handle is not None
return self.q_handle is not None or (hasattr(self, 'qkv_proj') and self.qkv_proj is not None)
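
For reference, a minimal standalone sketch of the fused-QKV split used in the QuiP path of forward_torch above: the qkv_proj output is sliced into query, key, and value blocks along the last dimension and rescaled by the per-projection scale tensors. The helper function name is illustrative only; shapes and scale names follow the diff.

import torch

def split_fused_qkv(qkv_states: torch.Tensor,
                    num_attention_heads: int,
                    num_key_value_heads: int,
                    head_dim: int,
                    q_scale: torch.Tensor,
                    k_scale: torch.Tensor,
                    v_scale: torch.Tensor):
    # Fused projection layout along the last dim: [queries | keys | values]
    q_size = num_attention_heads * head_dim
    kv_size = num_key_value_heads * head_dim
    query_states = q_scale * qkv_states[..., : q_size]
    key_states = k_scale * qkv_states[..., q_size : q_size + kv_size]
    value_states = v_scale * qkv_states[..., q_size + kv_size : q_size + 2 * kv_size]
    return query_states, key_states, value_states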

exllamav2/config.py (18 changes: 18 additions & 0 deletions)
@@ -44,6 +44,8 @@ class ExLlamaV2Config:
head_dim: int = 128 # Constant for all Llama models, except 3b

qkv_embed: bool = False
is_quip: bool = False
quip_params: dict = None


def __init__(self):
@@ -106,6 +108,9 @@ def prepare(self, no_tensors = False):
if "max_sequence_length" in read_config: self.max_seq_len = read_config["max_sequence_length"]
elif "max_position_embeddings" in read_config: self.max_seq_len = read_config["max_position_embeddings"]

self.is_quip = 'quip_params' in read_config
if self.is_quip: self.quip_params = read_config['quip_params']

# Model dimensions

self.head_dim = self.hidden_size // self.num_attention_heads
@@ -140,6 +145,19 @@
["mlp.down_proj"],
["mlp.gate_proj"],
["mlp.up_proj"]
] if not self.is_quip else [
["input_layernorm", "ln1"],
["post_attention_layernorm", "ln2"],
["self_attn.qkv_proj"],
["self_attn.o_proj"],
["self_attn.k_scale"],
["self_attn.o_scale"],
["self_attn.q_scale"],
["mlp.down_proj"],
["mlp.upgate_proj"],
["mlp.down_scale"],
["mlp.gate_scale"],
["mlp.up_scale"]
]

expect_keys = []
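
For context, a rough sketch of the detection logic that prepare() now applies: a checkpoint is treated as QuiP-quantized if and only if its config.json contains a quip_params block, which is then stored on the config object. The standalone function below is illustrative, not part of the PR.

import json

def detect_quip(config_path):
    # Mirror ExLlamaV2Config.prepare(): QuiP is flagged purely by the
    # presence of "quip_params" in the HuggingFace config.json.
    with open(config_path, "r") as f:
        read_config = json.load(f)
    is_quip = "quip_params" in read_config
    quip_params = read_config.get("quip_params")  # dict of QuiP settings, or None
    return is_quip, quip_params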