Skip to content

Commit 80b2877

Browse files
committed
implement R3
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 87e19de commit 80b2877

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

src/llmcompressor/modifiers/transform/spinquant/base.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,8 +236,22 @@ def _create_r2_scheme(self, head_dim: int) -> TransformScheme:
236236
)
237237

238238
def _create_r3_scheme(self, head_dim: int) -> TransformScheme:
239-
raise NotImplementedError(
240-
"SpinQuant R3 rotations will be added in a future release"
239+
return TransformScheme(
240+
type=self.transform_type,
241+
randomize=self.randomize,
242+
requires_grad=self.learnable,
243+
precision=self.precision,
244+
head_dim=head_dim,
245+
apply=[
246+
TransformArgs(
247+
targets=[self.mappings.attn],
248+
location="q_attn",
249+
),
250+
TransformArgs(
251+
targets=[self.mappings.attn],
252+
location="k_cache",
253+
),
254+
],
241255
)
242256

243257
def _create_r4_scheme(self) -> TransformScheme:

src/llmcompressor/modifiers/transform/spinquant/mappings.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ class SpinQuantMapping(BaseModel):
1414
layers (https://arxiv.org/pdf/2405.16406 Fig. 1).
1515
1616
:param embedding: name or regex of embedding layer
17+
:param attn: name or regex of attention block in decoder layer
1718
:param attn_q: name or regex of q_proj layer in attention block
1819
:param attn_k: name or regex of k_proj layer in attention block
1920
:param attn_v: name or regex of v_proj layer in attention block
@@ -29,6 +30,7 @@ class SpinQuantMapping(BaseModel):
2930

3031
embedding: str
3132

33+
attn: str
3234
attn_q: str
3335
attn_k: str
3436
attn_v: str
@@ -50,6 +52,7 @@ def cast_to_list(cls, value):
5052

5153
_default_mappings = SpinQuantMapping(
5254
embedding="re:.*embed_tokens$",
55+
attn="re:.*self_attn$",
5356
attn_q="re:.*q_proj$",
5457
attn_k="re:.*k_proj$",
5558
attn_v="re:.*v_proj$",

0 commit comments

Comments
 (0)