Commit fb3882f

Do not fail if fast_hadamard_transform is not present
1 parent 49f035d commit fb3882f
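
The change follows the standard optional-dependency pattern: try to import the CUDA-backed package and bind a fast code path, and on ImportError bind a pure-PyTorch path so importing the module never fails. Below is a minimal standalone sketch of that pattern; the `transform` and `_slow_transform` names are illustrative, not from the repo.

```python
import torch

def _slow_transform(x):
    # Stand-in for a pure-PyTorch fallback; torchao's real fallback is
    # matmul_hadU_slow in hadamard_utils.py.
    return x

try:
    # Optional CUDA kernel from https://github.com/Dao-AILab/fast-hadamard-transform
    from fast_hadamard_transform import hadamard_transform as _fast_transform

    def transform(x):
        # Use the CUDA kernel only for CUDA tensors; CPU tensors still
        # take the slow path.
        return _fast_transform(x) if x.is_cuda else _slow_transform(x)

except ImportError:
    # Package absent: importing this module must still succeed, so bind
    # the slow path unconditionally instead of raising.
    def transform(x):
        return _slow_transform(x)

print(transform(torch.randn(8)).shape)
```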

2 files changed, +26 -37 lines

test/prototype/test_spinquant.py

+1 -8

@@ -1,23 +1,16 @@
-import importlib
 import pytest
 import torch
 from torchao._models.llama.model import Transformer
 from torchao.prototype.spinquant import apply_spinquant


-def _is_package_available(pkg_name):
-    return importlib.util.find_spec(pkg_name) is not None
-
-
 def _init_model(name="7B", device="cpu", precision=torch.bfloat16):
     model = Transformer.from_name(name)
     model.to(device=device, dtype=precision)
     return model.eval()


-_AVAILABLE_DEVICES = ["cpu"]
-if torch.cuda.is_available() and _is_package_available("fast_hadamard_transform"):
-    _AVAILABLE_DEVICES.append("cuda")
+_AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


 @pytest.mark.parametrize("device", _AVAILABLE_DEVICES)
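
With the fallback in place, the test no longer probes for `fast_hadamard_transform`; it only gates the CUDA case on device availability. A small sketch of that parametrization pattern in isolation (the test body here is hypothetical, not from the repo):

```python
import pytest
import torch

# CUDA is added to the parameter matrix only when it is actually available,
# so the file collects and runs cleanly on CPU-only machines.
_AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])

@pytest.mark.parametrize("device", _AVAILABLE_DEVICES)
def test_roundtrip(device):
    # Illustrative test body only: move a tensor to the device and back.
    x = torch.randn(8, device=device)
    assert torch.allclose(x.cpu(), x.to(device).cpu())
```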

torchao/prototype/spinquant/hadamard_utils.py

+25 -29

@@ -11,20 +11,25 @@
 
 import torch
 
+from torchao.prototype.spinquant._hadamard_matrices import get_had172, get_had156, get_had140, get_had108, get_had60, get_had52, get_had36, get_had28, get_had44, get_had40, get_had20, get_had12
+
 try:
-    """
-    Note: fast_hadamard_transform package is required for CUDA support.
-
-    To install the fast_hadamard_transform package, run the following:
-    ```
-    pip install git+https://github.com/Dao-AILab/fast-hadamard-transform.git
-    ```
-    """
     from fast_hadamard_transform import hadamard_transform
-except:
-    pass
 
-from torchao.prototype.spinquant._hadamard_matrices import get_had172, get_had156, get_had140, get_had108, get_had60, get_had52, get_had36, get_had28, get_had44, get_had40, get_had20, get_had12
+    def matmul_hadU(X, hadK, K):
+        if X.is_cuda:
+            return matmul_hadU_fast(X, hadK, K)
+        else:
+            return matmul_hadU_slow(X, hadK, K)
+
+except ImportError:
+
+    print("NOTE: Using slow Hadamard transform for SpinQuant. "
+          "For better performance on GPU, install `fast_hadamard_transform`: "
+          "`pip install git+https://github.com/Dao-AILab/fast-hadamard-transform.git`")
+
+    def matmul_hadU(X, hadK, K):
+        return matmul_hadU_slow(X, hadK, K)
 
 
 class HadamardTransform(torch.autograd.Function):

@@ -113,14 +118,7 @@ def get_hadK(n, transpose=False):
     return hadK, K
 
 
-def matmul_hadU(X, hadK, K):
-    if X.device == torch.device("cpu"):
-        return matmul_hadU_cpu(X, hadK, K)
-    else:
-        return matmul_hadU_cuda(X, hadK, K)
-
-
-def matmul_hadU_cpu(X, hadK, K):
+def matmul_hadU_slow(X, hadK, K):
     n = X.shape[-1]
     input = X.clone().view(-1, n, 1)
     output = input.clone()

@@ -143,7 +141,7 @@ def matmul_hadU_cpu(X, hadK, K):
     return input.view(X.shape) / torch.tensor(n).sqrt()
 
 
-def matmul_hadU_cuda(X, hadK, K):
+def matmul_hadU_fast(X, hadK, K):
     n = X.shape[-1]
     if K == 1:
         return HadamardTransform.apply(X.contiguous()) / torch.tensor(n).sqrt()

@@ -161,14 +159,14 @@ def random_hadamard_matrix(size, device, seed=0):
     Q = Q * 2 - 1
     Q = torch.diag(Q)
     hadK, K = get_hadK(size)
-    return matmul_hadU_cpu(Q, hadK, K).to(device)
+    return matmul_hadU_slow(Q, hadK, K).to(device)
 
 
 def hadamard_matrix(size, device):
     # See https://cornell-relaxml.github.io/quip-sharp/ , Section "Randomized Hadamard Transformation"
     Q = torch.eye(size)
     hadK, K = get_hadK(size)
-    return matmul_hadU_cpu(Q, hadK, K).to(device)
+    return matmul_hadU_slow(Q, hadK, K).to(device)
 
 
 def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):

@@ -180,22 +178,20 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
 
     W = module.weight.data
     dtype_orig = W.dtype
-    device_orig = W.device
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    W = W.float().to(device=device)
+    W = W.float()
 
     if had_dim == -1:
         if output:
             had_K, K = get_hadK(out_features)
-            W = matmul_hadU(W.t(), had_K, K).t()
+            W = matmul_hadU(W.t(), had_K.to(W.device), K).t()
         else:
             had_K, K = get_hadK(in_features)
-            W = matmul_hadU(W, had_K, K)
+            W = matmul_hadU(W, had_K.to(W.device), K)
     else:
         if R2 is not None:
             hadK = R2.to(torch.float64)
         else:
-            hadK = hadamard_matrix(had_dim, device).to(torch.float64)
+            hadK = hadamard_matrix(had_dim, W.device).to(torch.float64)
 
     if output:
         W = W.t()

@@ -208,4 +204,4 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
     if output:
         W = W.t()
 
-    module.weight.data = W.to(device=device_orig, dtype=dtype_orig)
+    module.weight.data = W.to(dtype=dtype_orig)
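
For reference, the renamed `matmul_hadU_slow` path is a plain PyTorch Walsh-Hadamard transform. Below is a self-contained sketch of such a transform for power-of-two sizes with the same 1/sqrt(n) normalization, plus a quick sanity check; it is an illustration only, not the repo's implementation (which also handles the non-power-of-two `hadK` factors).

```python
import torch

def slow_hadamard(x: torch.Tensor) -> torch.Tensor:
    """Normalized Walsh-Hadamard transform over the last dim (size must be 2**k)."""
    n = x.shape[-1]
    out = x.clone()
    h = 1
    while h < n:
        # Butterfly step: combine elements at distance h within each block of 2h.
        out = out.reshape(*x.shape[:-1], n // (2 * h), 2, h)
        a, b = out[..., 0, :], out[..., 1, :]
        out = torch.stack((a + b, a - b), dim=-2)
        h *= 2
    # Same normalization convention as matmul_hadU_slow/_fast: divide by sqrt(n).
    return out.reshape(x.shape) / torch.tensor(float(n)).sqrt()

x = torch.randn(4, 16)
# The normalized transform is orthogonal and symmetric, hence an involution.
assert torch.allclose(slow_hadamard(slow_hadamard(x)), x, atol=1e-5)
```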
