@@ -11,20 +11,25 @@
 
 import torch
 
+from torchao.prototype.spinquant._hadamard_matrices import get_had172, get_had156, get_had140, get_had108, get_had60, get_had52, get_had36, get_had28, get_had44, get_had40, get_had20, get_had12
+
 try:
-    """
-    Note: fast_hadamard_transform package is required for CUDA support.
-
-    To install the fast_hadamard_transform package, run the following:
-    ```
-    pip install git+https://github.com/Dao-AILab/fast-hadamard-transform.git
-    ```
-    """
     from fast_hadamard_transform import hadamard_transform
-except:
-    pass
 
-from torchao.prototype.spinquant._hadamard_matrices import get_had172, get_had156, get_had140, get_had108, get_had60, get_had52, get_had36, get_had28, get_had44, get_had40, get_had20, get_had12
+    def matmul_hadU(X, hadK, K):
+        if X.is_cuda:
+            return matmul_hadU_fast(X, hadK, K)
+        else:
+            return matmul_hadU_slow(X, hadK, K)
+
+except ImportError:
+
+    print("NOTE: Using slow Hadamard transform for SpinQuant. "
+          "For better performance on GPU, install `fast_hadamard_transform`: "
+          "`pip install git+https://github.com/Dao-AILab/fast-hadamard-transform.git`")
+
+    def matmul_hadU(X, hadK, K):
+        return matmul_hadU_slow(X, hadK, K)
 
 
 class HadamardTransform(torch.autograd.Function):
@@ -113,14 +118,7 @@ def get_hadK(n, transpose=False):
     return hadK, K
 
 
-def matmul_hadU(X, hadK, K):
-    if X.device == torch.device("cpu"):
-        return matmul_hadU_cpu(X, hadK, K)
-    else:
-        return matmul_hadU_cuda(X, hadK, K)
-
-
-def matmul_hadU_cpu(X, hadK, K):
+def matmul_hadU_slow(X, hadK, K):
     n = X.shape[-1]
     input = X.clone().view(-1, n, 1)
     output = input.clone()
@@ -143,7 +141,7 @@ def matmul_hadU_cpu(X, hadK, K):
     return input.view(X.shape) / torch.tensor(n).sqrt()
 
 
-def matmul_hadU_cuda(X, hadK, K):
+def matmul_hadU_fast(X, hadK, K):
     n = X.shape[-1]
     if K == 1:
         return HadamardTransform.apply(X.contiguous()) / torch.tensor(n).sqrt()
@@ -161,14 +159,14 @@ def random_hadamard_matrix(size, device, seed=0):
     Q = Q * 2 - 1
     Q = torch.diag(Q)
     hadK, K = get_hadK(size)
-    return matmul_hadU_cpu(Q, hadK, K).to(device)
+    return matmul_hadU_slow(Q, hadK, K).to(device)
 
 
 def hadamard_matrix(size, device):
     # See https://cornell-relaxml.github.io/quip-sharp/ , Section "Randomized Hadamard Transformation"
     Q = torch.eye(size)
     hadK, K = get_hadK(size)
-    return matmul_hadU_cpu(Q, hadK, K).to(device)
+    return matmul_hadU_slow(Q, hadK, K).to(device)
 
 
 def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
@@ -180,22 +178,20 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
 
     W = module.weight.data
     dtype_orig = W.dtype
-    device_orig = W.device
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    W = W.float().to(device=device)
+    W = W.float()
 
     if had_dim == -1:
         if output:
             had_K, K = get_hadK(out_features)
-            W = matmul_hadU(W.t(), had_K, K).t()
+            W = matmul_hadU(W.t(), had_K.to(W.device), K).t()
         else:
             had_K, K = get_hadK(in_features)
-            W = matmul_hadU(W, had_K, K)
+            W = matmul_hadU(W, had_K.to(W.device), K)
     else:
         if R2 is not None:
             hadK = R2.to(torch.float64)
         else:
-            hadK = hadamard_matrix(had_dim, device).to(torch.float64)
+            hadK = hadamard_matrix(had_dim, W.device).to(torch.float64)
 
         if output:
             W = W.t()
@@ -208,4 +204,4 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
     if output:
         W = W.t()
 
-    module.weight.data = W.to(device=device_orig, dtype=dtype_orig)
+    module.weight.data = W.to(dtype=dtype_orig)
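
A minimal sketch of how the refactored dispatch behaves after this diff, assuming the helpers are importable from `torchao.prototype.spinquant.hadamard_utils` (the exact module path is not shown in the diff): without `fast_hadamard_transform` installed, `matmul_hadU` always takes the pure-PyTorch `matmul_hadU_slow` path, and with it installed the fast kernel is used only for CUDA inputs.

```python
import torch

# Assumed import path; the diff only shows the sibling _hadamard_matrices module.
from torchao.prototype.spinquant.hadamard_utils import (
    get_hadK,
    matmul_hadU,
    random_hadamard_matrix,
)

x = torch.randn(4, 128)          # 128 is a power of 2, so get_hadK returns K == 1
had_K, K = get_hadK(x.shape[-1])

# On CPU, or when fast_hadamard_transform is missing, this runs matmul_hadU_slow;
# on a CUDA tensor with the package installed it runs matmul_hadU_fast.
y = matmul_hadU(x, had_K, K)

# random_hadamard_matrix builds the rotation with the slow path, then moves it
# to the requested device.
R = random_hadamard_matrix(128, device="cpu")
print(y.shape, R.shape)
```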