From 1543e4f9baf686a53d526b05780dfe14d0bfc999 Mon Sep 17 00:00:00 2001
From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com>
Date: Fri, 18 Oct 2024 10:59:25 +0200
Subject: [PATCH] Fix broken custom op for PyTorch < 2.4

---
 torchao/prototype/spinquant/hadamard_utils.py | 30 +++++++++++++++----
 1 file changed, 25 insertions(+), 5 deletions(-)

diff --git a/torchao/prototype/spinquant/hadamard_utils.py b/torchao/prototype/spinquant/hadamard_utils.py
index b38c666565..6e17a04de9 100644
--- a/torchao/prototype/spinquant/hadamard_utils.py
+++ b/torchao/prototype/spinquant/hadamard_utils.py
@@ -11,8 +11,9 @@
 
 import torch
 
-from torchao.ops import register_custom_op
+from torchao.ops import lib
 from torchao.prototype.spinquant._hadamard_matrices import get_had172, get_had156, get_had140, get_had108, get_had60, get_had52, get_had36, get_had28, get_had44, get_had40, get_had20, get_had12
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
 
 try:
     from fast_hadamard_transform import hadamard_transform as _fast_hadamard_transform
@@ -33,7 +34,26 @@ def matmul_hadU(X, hadK, K):
         return matmul_hadU_slow(X, hadK, K)
 
 
-@torch.library.custom_op("torchao::hadamard_transform", mutates_args=())
+def register_custom_op_impl(name):
+    def decorator(func):
+        if TORCH_VERSION_AT_LEAST_2_4:
+            return torch.library.custom_op(f"{name}", mutates_args=())(func)
+        else:
+            lib.define("hadamard_transform(Tensor x, float scale = 0.0) -> Tensor")
+            return torch.library.impl(f"{name}", "cuda")(func)
+    return decorator
+
+
+def register_custom_op_abstract(name):
+    def decorator(func):
+        if TORCH_VERSION_AT_LEAST_2_4:
+            return torch.library.register_fake(f"{name}")(func)
+        else:
+            return torch.library.impl_abstract(f"{name}")(func)
+    return decorator
+
+
+@register_custom_op_impl("torchao::hadamard_transform")
 def hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
     """
     Arguments:
@@ -51,7 +71,7 @@ def hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
     return _fast_hadamard_transform(x, scale)
 
 
-@register_custom_op("torchao::hadamard_transform")
+@register_custom_op_abstract("torchao::hadamard_transform")
 def _(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
     torch._check(x.dim() >= 1, lambda: f"input should be at least a 1D tensor, got {x.dim()}D")
     return torch.empty_like(x)
@@ -169,9 +189,9 @@ def matmul_hadU_slow(X, hadK, K):
 def matmul_hadU_fast(X, hadK, K):
     n = X.shape[-1]
     if K == 1:
-        return hadamard_transform(X.contiguous()) / torch.tensor(n).sqrt()
+        return torch.ops.torchao.hadamard_transform.default(X.contiguous()) / torch.tensor(n).sqrt()
     input = X.view(-1, K, n // K)
-    input = hadamard_transform(input.contiguous()) / torch.tensor(n).sqrt()
+    input = torch.ops.torchao.hadamard_transform.default(input.contiguous()) / torch.tensor(n).sqrt()
     input = hadK.to(input.device).to(input.dtype) @ input
     return input.reshape(X.shape)
 
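
After this patch the op is reachable through the dispatcher under the same
qualified name on both sides of the version gate. A minimal smoke-test sketch
(assumes a CUDA device and the fast_hadamard_transform package installed; the
shape and dtype below are illustrative, not part of this patch):

    import torch
    import torchao.prototype.spinquant.hadamard_utils  # import triggers op registration

    x = torch.randn(4, 256, device="cuda", dtype=torch.float16)
    # Pass scale explicitly: the pre-2.4 schema declares `scale = 0.0` while
    # the Python signature defaults to 1.0, so relying on the default would
    # give different results across the two registration paths.
    y = torch.ops.torchao.hadamard_transform.default(x, 1.0)
    assert y.shape == x.shape

Calling through torch.ops rather than the plain Python function keeps both
versions on the dispatcher path, so the registered fake/abstract impl is
picked up during tracing and under torch.compile.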