|
18 | 18 | from accelerate import init_empty_weights
|
19 | 19 |
|
20 | 20 | from text_generation_server.utils.gptq.quant_linear import QuantLinear
|
21 |
| -from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM |
| 21 | +from text_generation_server.utils.import_utils import ( |
| 22 | + IS_CUDA_SYSTEM, |
| 23 | + IS_ROCM_SYSTEM, |
| 24 | + IS_XPU_SYSTEM, |
| 25 | +) |
22 | 26 | from text_generation_server.utils.log import log_once
|
23 | 27 |
|
| 28 | +if IS_XPU_SYSTEM: |
| 29 | + import intel_extension_for_pytorch as ipex |
| 30 | + |
24 | 31 | HAS_AWQ = True
|
25 | 32 | try:
|
26 | 33 | from text_generation_server.utils.awq.quantize.qmodule import WQLinear
|
@@ -646,7 +653,13 @@ def forward(self, hidden_states, residual=None):
|
646 | 653 | if residual is not None:
|
647 | 654 | hidden_states += residual
|
648 | 655 | residual = hidden_states
|
649 |
| - out = torch.ops.torch_ipex.fast_layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps) |
| 656 | + out = ipex.llm.modules.FastLayerNorm.apply( |
| 657 | + hidden_states, |
| 658 | + self.normalized_shape, |
| 659 | + self.eps, |
| 660 | + self.weight, |
| 661 | + self.bias, |
| 662 | + ) |
650 | 663 | return out, residual
|
651 | 664 | elif hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
|
652 | 665 | if residual is not None:
|
@@ -698,8 +711,11 @@ def forward(self, hidden_states, residual=None):
|
698 | 711 | if residual is not None:
|
699 | 712 | hidden_states += residual
|
700 | 713 | residual = hidden_states
|
701 |
| - out = torch.ops.torch_ipex.rms_norm( |
702 |
| - hidden_states, [hidden_states.size(-1)], self.weight, self.variance_epsilon |
| 714 | + out = ipex.llm.modules.RMSNorm.apply( |
| 715 | + hidden_states, |
| 716 | + [hidden_states.size(-1)], |
| 717 | + self.weight, |
| 718 | + self.variance_epsilon, |
703 | 719 | )
|
704 | 720 | return out[0], residual
|
705 | 721 | elif hidden_states.shape[-1] > 8192:
|
@@ -829,15 +845,14 @@ def forward(
|
829 | 845 | # Inplace operation, updating query and key.
|
830 | 846 | pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
|
831 | 847 | elif IS_XPU_SYSTEM:
|
832 |
| - sin = sin.expand(query.shape) |
833 |
| - cos = cos.expand(query.shape) |
834 |
| - torch.ops.torch_ipex.apply_rotary_embedding_half_qk(query, key, sin, cos, query, key) |
| 848 | + ipex.llm.modules.RotaryEmbedding.apply( |
| 849 | + query, key, sin, cos, query.size(-1), True |
| 850 | + ) |
835 | 851 | else:
|
836 | 852 | raise ValueError(
|
837 | 853 | "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
|
838 | 854 | )
|
839 | 855 |
|
840 |
| - |
841 | 856 | @classmethod
|
842 | 857 | def static(cls, config, dim, base, device):
|
843 | 858 | inv_freq = _create_inv_freq(dim, base, device)
|
@@ -953,8 +968,6 @@ def get_cos_sin(
|
953 | 968 | cos = torch.index_select(self._cos_cached, 0, position_ids)
|
954 | 969 | sin = torch.index_select(self._sin_cached, 0, position_ids)
|
955 | 970 |
|
956 |
| - if IS_XPU_SYSTEM: |
957 |
| - return cos.unsqueeze(1).repeat(1, 1, 2), sin.unsqueeze(1).repeat(1, 1, 2) |
958 | 971 | # Note: this unsqueeze is not necessary with the ROCm + vLLM RoPE implementation, but we leave it as is to avoid yet another control-flow branch.
|
959 | 972 | return cos.unsqueeze(1), sin.unsqueeze(1)
|
960 | 973 |
|
|
0 commit comments