Commit 1543e4f: Fix broken custom op for PyTorch < 2.4
tobiasvanderwerff committed Oct 18, 2024
1 parent 11198e2

Showing 1 changed file with 25 additions and 5 deletions: torchao/prototype/spinquant/hadamard_utils.py

@@ -11,8 +11,9 @@

import torch

-from torchao.ops import register_custom_op
+from torchao.ops import lib
from torchao.prototype.spinquant._hadamard_matrices import get_had172, get_had156, get_had140, get_had108, get_had60, get_had52, get_had36, get_had28, get_had44, get_had40, get_had20, get_had12
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4

try:
    from fast_hadamard_transform import hadamard_transform as _fast_hadamard_transform

@@ -33,7 +34,26 @@ def matmul_hadU(X, hadK, K):
        return matmul_hadU_slow(X, hadK, K)


-@torch.library.custom_op("torchao::hadamard_transform", mutates_args=())
+def register_custom_op_impl(name):
+    def decorator(func):
+        if TORCH_VERSION_AT_LEAST_2_4:
+            return torch.library.custom_op(f"{name}", mutates_args=())(func)
+        else:
+            lib.define("hadamard_transform(Tensor x, float scale = 0.0) -> Tensor")
+            return torch.library.impl(f"{name}", "cuda")(func)
+    return decorator
+
+
+def register_custom_op_abstract(name):
+    def decorator(func):
+        if TORCH_VERSION_AT_LEAST_2_4:
+            return torch.library.register_fake(f"{name}")(func)
+        else:
+            return torch.library.impl_abstract(f"{name}")(func)
+    return decorator
+
+
+@register_custom_op_impl("torchao::hadamard_transform")
def hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
"""
Arguments:
Expand All @@ -51,7 +71,7 @@ def hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
return _fast_hadamard_transform(x, scale)


-@register_custom_op("torchao::hadamard_transform")
+@register_custom_op_abstract("torchao::hadamard_transform")
def _(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    torch._check(x.dim() >= 1, lambda: f"input should be at least a 1D tensor, got {x.dim()}D")
    return torch.empty_like(x)
@@ -169,9 +189,9 @@ def matmul_hadU_slow(X, hadK, K):
def matmul_hadU_fast(X, hadK, K):
    n = X.shape[-1]
    if K == 1:
-        return hadamard_transform(X.contiguous()) / torch.tensor(n).sqrt()
+        return torch.ops.torchao.hadamard_transform.default(X.contiguous()) / torch.tensor(n).sqrt()
    input = X.view(-1, K, n // K)
-    input = hadamard_transform(input.contiguous()) / torch.tensor(n).sqrt()
+    input = torch.ops.torchao.hadamard_transform.default(input.contiguous()) / torch.tensor(n).sqrt()
    input = hadK.to(input.device).to(input.dtype) @ input
    return input.reshape(X.shape)
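
For readers who want the pattern in isolation: below is a minimal, self-contained sketch of the version-gated registration this commit introduces. It is an illustration, not the torchao source; the mylib::double op, its schema string, and the CPU dispatch key are hypothetical stand-ins for torchao::hadamard_transform, its schema, and "cuda", and the hasattr check stands in for torchao's TORCH_VERSION_AT_LEAST_2_4 flag.

import torch

# Stand-in version check (assumption: the public torch.library.custom_op
# API first appeared in PyTorch 2.4).
TORCH_AT_LEAST_2_4 = hasattr(torch.library, "custom_op")

# The pre-2.4 path needs a Library object to define the op schema on.
_lib = torch.library.Library("mylib", "FRAGMENT")  # hypothetical namespace

def register_op_impl(name, schema):
    def decorator(func):
        if TORCH_AT_LEAST_2_4:
            # 2.4+: custom_op infers the schema from the type annotations.
            return torch.library.custom_op(name, mutates_args=())(func)
        else:
            # Older releases: define the schema by hand, then attach the
            # implementation under an explicit dispatch key.
            _lib.define(schema)
            return torch.library.impl(name, "cpu")(func)
    return decorator

def register_op_fake(name):
    def decorator(func):
        if TORCH_AT_LEAST_2_4:
            return torch.library.register_fake(name)(func)
        else:
            # impl_abstract is the pre-2.4 spelling of register_fake.
            return torch.library.impl_abstract(name)(func)
    return decorator

@register_op_impl("mylib::double", "double(Tensor x) -> Tensor")
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

@register_op_fake("mylib::double")
def _(x: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only "fake" implementation used during tracing.
    return torch.empty_like(x)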

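This also explains why the call sites in matmul_hadU_fast switch from the decorated Python name to torch.ops.torchao.hadamard_transform.default: the two registration paths do not leave the module-level name bound to the same kind of callable, whereas resolving the operator by qualified name through torch.ops routes every call through the dispatcher and behaves identically on both paths. Continuing the hypothetical mylib::double sketch above:

x = torch.ones(3)

# Works regardless of which branch registered the op.
y = torch.ops.mylib.double.default(x)
print(y)  # tensor([2., 2., 2.])

# The fake registration is what lets the op participate in tracing with
# fake tensors (e.g. during torch.compile graph capture) without running
# the real kernel.
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    fy = torch.ops.mylib.double.default(torch.empty(16))
    assert fy.shape == (16,)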
