Skip to content

Commit

Permalink
[Bugfix] Make torch registration of punica ops optional (vllm-project…
Browse files Browse the repository at this point in the history
  • Loading branch information
bnellnm authored Aug 28, 2024
1 parent 737129b commit ab365d1
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 21 deletions.
9 changes: 6 additions & 3 deletions vllm/lora/ops/bgmv_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ def _bgmv_expand(
return


bgmv_expand = torch.library.custom_op("lora::bgmv_expand",
_bgmv_expand,
mutates_args=["output_tensor"])
try:
bgmv_expand = torch.library.custom_op("lora::bgmv_expand",
_bgmv_expand,
mutates_args=["output_tensor"])
except AttributeError:
bgmv_expand = _bgmv_expand
9 changes: 6 additions & 3 deletions vllm/lora/ops/bgmv_expand_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ def _bgmv_expand_slice(
return


bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice",
_bgmv_expand_slice,
mutates_args=["output_tensor"])
try:
bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice",
_bgmv_expand_slice,
mutates_args=["output_tensor"])
except AttributeError:
bgmv_expand_slice = _bgmv_expand_slice
9 changes: 6 additions & 3 deletions vllm/lora/ops/bgmv_shrink.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ def _bgmv_shrink(
return


bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink",
_bgmv_shrink,
mutates_args=["output_tensor"])
try:
bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink",
_bgmv_shrink,
mutates_args=["output_tensor"])
except AttributeError:
bgmv_shrink = _bgmv_shrink
9 changes: 6 additions & 3 deletions vllm/lora/ops/sgmv_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@ def _sgmv_expand(
return


sgmv_expand = torch.library.custom_op("lora::sgmv_expand",
_sgmv_expand,
mutates_args=["output_tensor"])
try:
sgmv_expand = torch.library.custom_op("lora::sgmv_expand",
_sgmv_expand,
mutates_args=["output_tensor"])
except AttributeError:
sgmv_expand = _sgmv_expand
9 changes: 6 additions & 3 deletions vllm/lora/ops/sgmv_expand_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ def _sgmv_expand_slice(
return


sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice",
_sgmv_expand_slice,
mutates_args=["output_tensor"])
try:
sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice",
_sgmv_expand_slice,
mutates_args=["output_tensor"])
except AttributeError:
sgmv_expand_slice = _sgmv_expand_slice
9 changes: 6 additions & 3 deletions vllm/lora/ops/sgmv_shrink.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ def _sgmv_shrink(
return


sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink",
_sgmv_shrink,
mutates_args=["output_tensor"])
try:
sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink",
_sgmv_shrink,
mutates_args=["output_tensor"])
except AttributeError:
sgmv_shrink = _sgmv_shrink
4 changes: 1 addition & 3 deletions vllm/lora/punica.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@
import torch

from vllm.triton_utils import HAS_TRITON
from vllm.utils import is_xpu

# FIXME: xpu path doesn't support torch.library.custom_op
if HAS_TRITON and not is_xpu():
if HAS_TRITON:
from vllm.lora.ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
Expand Down

0 comments on commit ab365d1

Please sign in to comment.