From b116d35e563f51a98213168d79ab8f73955644b1 Mon Sep 17 00:00:00 2001 From: bnellnm <49004751+bnellnm@users.noreply.github.com> Date: Wed, 28 Aug 2024 18:11:49 -0400 Subject: [PATCH] [Bugfix] Make torch registration of punica ops optional (#7970) Signed-off-by: Alvant --- vllm/lora/ops/bgmv_expand.py | 9 ++++++--- vllm/lora/ops/bgmv_expand_slice.py | 9 ++++++--- vllm/lora/ops/bgmv_shrink.py | 9 ++++++--- vllm/lora/ops/sgmv_expand.py | 9 ++++++--- vllm/lora/ops/sgmv_expand_slice.py | 9 ++++++--- vllm/lora/ops/sgmv_shrink.py | 9 ++++++--- vllm/lora/punica.py | 4 +--- 7 files changed, 37 insertions(+), 21 deletions(-) diff --git a/vllm/lora/ops/bgmv_expand.py b/vllm/lora/ops/bgmv_expand.py index 0bbc1844ef455..619408b9315cf 100644 --- a/vllm/lora/ops/bgmv_expand.py +++ b/vllm/lora/ops/bgmv_expand.py @@ -160,6 +160,9 @@ def _bgmv_expand( return -bgmv_expand = torch.library.custom_op("lora::bgmv_expand", - _bgmv_expand, - mutates_args=["output_tensor"]) +try: + bgmv_expand = torch.library.custom_op("lora::bgmv_expand", + _bgmv_expand, + mutates_args=["output_tensor"]) +except AttributeError: + bgmv_expand = _bgmv_expand diff --git a/vllm/lora/ops/bgmv_expand_slice.py b/vllm/lora/ops/bgmv_expand_slice.py index 87d7d9902a4c1..c16db233891a5 100644 --- a/vllm/lora/ops/bgmv_expand_slice.py +++ b/vllm/lora/ops/bgmv_expand_slice.py @@ -173,6 +173,9 @@ def _bgmv_expand_slice( return -bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice", - _bgmv_expand_slice, - mutates_args=["output_tensor"]) +try: + bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice", + _bgmv_expand_slice, + mutates_args=["output_tensor"]) +except AttributeError: + bgmv_expand_slice = _bgmv_expand_slice diff --git a/vllm/lora/ops/bgmv_shrink.py b/vllm/lora/ops/bgmv_shrink.py index c979d758492db..0846ff36b1692 100644 --- a/vllm/lora/ops/bgmv_shrink.py +++ b/vllm/lora/ops/bgmv_shrink.py @@ -142,6 +142,9 @@ def _bgmv_shrink( return -bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink", - _bgmv_shrink, - mutates_args=["output_tensor"]) +try: + bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink", + _bgmv_shrink, + mutates_args=["output_tensor"]) +except AttributeError: + bgmv_shrink = _bgmv_shrink diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index 80a0b605b0fe2..c71332d8bdfb2 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -192,6 +192,9 @@ def _sgmv_expand( return -sgmv_expand = torch.library.custom_op("lora::sgmv_expand", - _sgmv_expand, - mutates_args=["output_tensor"]) +try: + sgmv_expand = torch.library.custom_op("lora::sgmv_expand", + _sgmv_expand, + mutates_args=["output_tensor"]) +except AttributeError: + sgmv_expand = _sgmv_expand diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py index 53237166a1c68..b4ae9a2acbb5c 100644 --- a/vllm/lora/ops/sgmv_expand_slice.py +++ b/vllm/lora/ops/sgmv_expand_slice.py @@ -205,6 +205,9 @@ def _sgmv_expand_slice( return -sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice", - _sgmv_expand_slice, - mutates_args=["output_tensor"]) +try: + sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice", + _sgmv_expand_slice, + mutates_args=["output_tensor"]) +except AttributeError: + sgmv_expand_slice = _sgmv_expand_slice diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py index 51d2a09eee94b..c0791c260e915 100644 --- a/vllm/lora/ops/sgmv_shrink.py +++ b/vllm/lora/ops/sgmv_shrink.py @@ -189,6 +189,9 @@ def _sgmv_shrink( return -sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink", - _sgmv_shrink, - mutates_args=["output_tensor"]) +try: + sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink", + _sgmv_shrink, + mutates_args=["output_tensor"]) +except AttributeError: + sgmv_shrink = _sgmv_shrink diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index d666fc293757b..6d5c834299961 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -10,10 +10,8 @@ import torch from vllm.triton_utils import HAS_TRITON -from vllm.utils import is_xpu -# FIXME: xpu path doesn't support torch.library.custom_op -if HAS_TRITON and not is_xpu(): +if HAS_TRITON: from vllm.lora.ops.bgmv_expand import bgmv_expand from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice from vllm.lora.ops.bgmv_shrink import bgmv_shrink