Commit b3989cf

align to ipex llm ops

Signed-off-by: Wang, Yi A <[email protected]>
1 parent 3a79d15 · commit b3989cf

File tree: 3 files changed, +47 -19 lines


server/text_generation_server/utils/flash_attn.py

Lines changed: 10 additions & 4 deletions
@@ -4,14 +4,21 @@
 from loguru import logger
 import math
 
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
 
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
 HAS_FLASH_ATTN = True
 HAS_FLASH_ATTN_V2_CUDA = False
 HAS_FLASH_ATTN_V2_ROCM = False
 
+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex
+
 if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
     if not torch.cuda.is_available():
         raise ImportError("CUDA is not available")
@@ -90,7 +97,7 @@ def attention(
             raise ValueError(
                 f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
             )
-        return torch.xpu.varlen_fwd(
+        return ipex.llm.modules.VarlenAttention.apply(
             q,
             k,
             v,
@@ -104,10 +111,9 @@ def attention(
             False,
             True,
             False,
-            None
+            None,
         )
 
-
     if HAS_FLASH_ATTN_V2_CUDA:
         return flash_attn_2_cuda.varlen_fwd(
             q,
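
Note: the pattern throughout this commit is to import intel_extension_for_pytorch only when the XPU flag is set, then route the XPU branch through the ipex.llm.modules wrappers instead of the removed torch.xpu ops. A minimal sketch of that guarded-import pattern is below; xpu_is_available is a hypothetical stand-in for the repository's IS_XPU_SYSTEM flag, whose real detection logic lives in text_generation_server/utils/import_utils.py.

# Sketch of the guarded-import pattern used above. `xpu_is_available` is a
# hypothetical stand-in for the repo's IS_XPU_SYSTEM flag (an assumption for
# illustration; the real check may differ).
import importlib.util

import torch


def xpu_is_available() -> bool:
    # Only treat the system as XPU when the Intel extension is installed and
    # torch exposes a usable XPU device.
    if importlib.util.find_spec("intel_extension_for_pytorch") is None:
        return False
    return hasattr(torch, "xpu") and torch.xpu.is_available()


if xpu_is_available():
    # Imported lazily so CUDA/ROCm deployments never need the package.
    import intel_extension_for_pytorch as ipex  # noqa: F401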

server/text_generation_server/utils/layers.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,16 @@
 from accelerate import init_empty_weights
 
 from text_generation_server.utils.gptq.quant_linear import QuantLinear
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
 from text_generation_server.utils.log import log_once
 
+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex
+
 HAS_AWQ = True
 try:
     from text_generation_server.utils.awq.quantize.qmodule import WQLinear
@@ -646,7 +653,13 @@ def forward(self, hidden_states, residual=None):
             if residual is not None:
                 hidden_states += residual
             residual = hidden_states
-            out = torch.ops.torch_ipex.fast_layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)
+            out = ipex.llm.modules.FastLayerNorm.apply(
+                hidden_states,
+                self.normalized_shape,
+                self.eps,
+                self.weight,
+                self.bias,
+            )
             return out, residual
         elif hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
             if residual is not None:
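
Note: as the diff shows, the fused wrapper takes eps before weight and bias. For readers without an XPU build, a plain-PyTorch reference of the same layer-norm call might look like the sketch below (standard LayerNorm semantics are assumed; this is not the fused kernel).

# Illustrative pure-PyTorch equivalent of the fused XPU layer norm call above.
# Argument order (eps before weight/bias) mirrors the diff.
import torch
import torch.nn.functional as F


def layer_norm_reference(hidden_states, normalized_shape, eps, weight, bias):
    # F.layer_norm implements the same normalization the fused op is expected
    # to perform on XPU.
    return F.layer_norm(hidden_states, normalized_shape, weight, bias, eps)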
@@ -698,8 +711,11 @@ def forward(self, hidden_states, residual=None):
             if residual is not None:
                 hidden_states += residual
             residual = hidden_states
-            out = torch.ops.torch_ipex.rms_norm(
-                hidden_states, [hidden_states.size(-1)], self.weight, self.variance_epsilon
+            out = ipex.llm.modules.RMSNorm.apply(
+                hidden_states,
+                [hidden_states.size(-1)],
+                self.weight,
+                self.variance_epsilon,
             )
             return out[0], residual
         elif hidden_states.shape[-1] > 8192:
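
Note: the RMS-norm branch follows the same shape as the layer-norm one. As a reference for what the fused call is expected to compute, here is standard RMSNorm math in plain PyTorch; it is a sketch of the math only and does not claim to be bit-identical to the ipex kernel.

# Reference RMSNorm: x * rsqrt(mean(x^2) + eps), scaled by the learned weight,
# with the reduction done in float32 for numerical stability.
import torch


def rms_norm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    input_dtype = hidden_states.dtype
    x = hidden_states.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return weight * x.to(input_dtype)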
@@ -829,15 +845,14 @@ def forward(
             # Inplace operation, updating query and key.
             pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
         elif IS_XPU_SYSTEM:
-            sin = sin.expand(query.shape)
-            cos = cos.expand(query.shape)
-            torch.ops.torch_ipex.apply_rotary_embedding_half_qk(query, key, sin, cos, query, key)
+            ipex.llm.modules.RotaryEmbedding.apply(
+                query, key, sin, cos, query.size(-1), True
+            )
         else:
             raise ValueError(
                 "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
             )
 
-
     @classmethod
     def static(cls, config, dim, base, device):
         inv_freq = _create_inv_freq(dim, base, device)
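
Note: the old code expanded sin/cos to the query shape before calling the half-QK rotary op; the new wrapper is handed the raw sin/cos plus the rotary dimension and applies the rotation in place. As a reference for the underlying math, a non-fused half rotation in plain PyTorch looks roughly like this (a sketch only; layouts and interleaving may differ from the fused kernel).

# Non-fused reference of the "half" rotary rotation: split the head dimension
# in two and rotate the pairs by the cached cos/sin values.
import torch


def apply_rotary_half(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [..., rotary_dim]; cos, sin: broadcastable to [..., rotary_dim // 2]
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)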
@@ -953,8 +968,6 @@ def get_cos_sin(
         cos = torch.index_select(self._cos_cached, 0, position_ids)
         sin = torch.index_select(self._sin_cached, 0, position_ids)
 
-        if IS_XPU_SYSTEM:
-            return cos.unsqueeze(1).repeat(1, 1, 2), sin.unsqueeze(1).repeat(1, 1, 2)
         # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
         return cos.unsqueeze(1), sin.unsqueeze(1)
 
server/text_generation_server/utils/paged_attention.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
 import torch
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
+
 if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
     from vllm import cache_ops
     from vllm import attention_ops
 
 _PARTITION_SIZE = 512
 
+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex
 
 
 def reshape_and_cache(
@@ -18,7 +25,9 @@ def reshape_and_cache(
     if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
         cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
     elif IS_XPU_SYSTEM:
-        torch.xpu.reshape_and_cache(key, value, key_cache, value_cache, slots)
+        ipex.llm.modules.PagedAttention.reshape_and_cache(
+            key, value, key_cache, value_cache, slots
+        )
 
 
 def attention(
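
Note: reshape_and_cache scatters the freshly computed key/value vectors into the paged KV cache at the slot indices chosen by the scheduler. A simplified reference with a flat cache layout is sketched below; the flat layout is an assumption for illustration, since the real vllm/ipex kernels use blocked, backend-specific layouts and fused copies.

# Simplified reference of what a reshape_and_cache op does: write each token's
# key/value into the KV cache at its assigned slot.
import torch


def reshape_and_cache_reference(key, value, key_cache, value_cache, slots):
    # key, value: [num_tokens, num_heads, head_size]
    # key_cache, value_cache: [num_slots, num_heads, head_size] (flat layout assumption)
    # slots: [num_tokens] int64 indices into the cache
    key_cache[slots] = key
    value_cache[slots] = value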
@@ -60,18 +69,18 @@
     # to parallelize.
     if IS_XPU_SYSTEM:
         query = query.contiguous()
-        return torch.xpu.IpexPaged_attention(
+        return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
             out,
             query,
             key_cache,
             value_cache,
             kv_head_mapping,
+            softmax_scale,
             block_tables,
             input_lengths,
-            softmax_scale,
             block_size,
             max_s,
-            None
+            None,
         )
 
     use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
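
Note: besides the rename, softmax_scale now sits between kv_head_mapping and block_tables in the argument list. For orientation, the decode-time computation being dispatched is ordinary single-query attention over the cached keys/values; a stripped-down reference that ignores block tables and grouped-query head mapping is sketched below (it is not the fused paged-attention kernel).

# Stripped-down reference of single-query ("decode") attention over an
# un-paged cache, ignoring block tables and GQA head mapping.
import torch


def single_query_attention_reference(query, keys, values, softmax_scale):
    # query: [num_heads, head_size]; keys, values: [seq_len, num_heads, head_size]
    scores = torch.einsum("hd,shd->hs", query.float(), keys.float()) * softmax_scale
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("hs,shd->hd", probs, values.float()).to(query.dtype)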
