2 files changed: +19, -2 lines

src/llmcompressor/modifiers/transform/spinquant

@@ -236,8 +236,22 @@ def _create_r2_scheme(self, head_dim: int) -> TransformScheme:
         )
 
     def _create_r3_scheme(self, head_dim: int) -> TransformScheme:
-        raise NotImplementedError(
-            "SpinQuant R3 rotations will be added in a future release"
+        return TransformScheme(
+            type=self.transform_type,
+            randomize=self.randomize,
+            requires_grad=self.learnable,
+            precision=self.precision,
+            head_dim=head_dim,
+            apply=[
+                TransformArgs(
+                    targets=[self.mappings.attn],
+                    location="q_attn",
+                ),
+                TransformArgs(
+                    targets=[self.mappings.attn],
+                    location="k_cache",
+                ),
+            ],
         )
 
     def _create_r4_scheme(self) -> TransformScheme:
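Note: R3 rotates queries (at the "q_attn" location) and keys (at the "k_cache" location) with the same per-head orthogonal matrix, so the rotation cancels inside the attention product and model outputs are unchanged. A minimal sketch in plain PyTorch (not the llm-compressor API) of why that cancellation holds:

# Toy demonstration, assuming R3 is an orthogonal per-head rotation:
# (q @ R) @ (k @ R).T == q @ (R @ R.T) @ k.T == q @ k.T
import torch

head_dim = 64
q = torch.randn(8, head_dim, dtype=torch.float64)  # toy per-head queries
k = torch.randn(8, head_dim, dtype=torch.float64)  # toy per-head keys

# Random orthogonal matrix standing in for the Hadamard-style R3 rotation
R, _ = torch.linalg.qr(torch.randn(head_dim, head_dim, dtype=torch.float64))

scores_before = q @ k.T
scores_after = (q @ R) @ (k @ R).T
assert torch.allclose(scores_before, scores_after)  # attention scores preserved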
@@ -14,6 +14,7 @@ class SpinQuantMapping(BaseModel):
     layers (https://arxiv.org/pdf/2405.16406 Fig. 1).
 
     :param embedding: name or regex of embedding layer
+    :param attn: name or regex of attention block in decoder layer
     :param attn_q: name or regex of q_proj layer in attention block
     :param attn_k: name or regex of k_proj layer in attention block
     :param attn_v: name or regex of v_proj layer in attention block
@@ -29,6 +30,7 @@ class SpinQuantMapping(BaseModel):
 
     embedding: str
 
+    attn: str
     attn_q: str
     attn_k: str
     attn_v: str
@@ -50,6 +52,7 @@ def cast_to_list(cls, value):
 
 _default_mappings = SpinQuantMapping(
     embedding="re:.*embed_tokens$",
+    attn="re:.*self_attn$",
     attn_q="re:.*q_proj$",
     attn_k="re:.*k_proj$",
     attn_v="re:.*v_proj$",
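The new default attn entry follows the same "re:" regex convention as the other mappings. A rough illustration (hypothetical matcher, not llm-compressor's own target-resolution code) of which module names the pattern would select in a Llama-style layer tree:

import re

def matches(pattern: str, name: str) -> bool:
    # Assumption: a "re:" prefix marks the rest of the string as a regex
    # applied to the fully qualified module name; otherwise exact match.
    if pattern.startswith("re:"):
        return re.match(pattern[3:], name) is not None
    return pattern == name

module_names = [
    "model.layers.0.self_attn",         # matched: ends with self_attn
    "model.layers.0.self_attn.q_proj",  # not matched: ends with q_proj
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.mlp.down_proj",
]
print([n for n in module_names if matches("re:.*self_attn$", n)])
# -> ['model.layers.0.self_attn']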