diff --git a/vllm_ascend/lora/punica_wrapper/lora_ops.py b/vllm_ascend/lora/punica_wrapper/lora_ops.py index a8ff21d748a..e8bf8ad9717 100644 --- a/vllm_ascend/lora/punica_wrapper/lora_ops.py +++ b/vllm_ascend/lora/punica_wrapper/lora_ops.py @@ -52,14 +52,9 @@ def bgmv_expand_slice(inputs: torch.Tensor, slice_offset: int, slice_size: int, add_inputs: bool = True): - return torch.ops._C.bgmv_expand( - inputs, - lora_b_weights, - lora_indices_tensor, - output_tensor, - slice_offset, - slice_size - ) + return torch.ops._C.bgmv_expand(inputs, lora_b_weights, + lora_indices_tensor, output_tensor, + slice_offset, slice_size) def sgmv_shrink( @@ -74,8 +69,9 @@ def sgmv_shrink( token_nums: int, scaling: float, ): - return torch.ops._C.sgmv_shrink(inputs, lora_a_weights, lora_indices_tensor, - seq_len_tensor, output_tensor, scaling) + return torch.ops._C.sgmv_shrink(inputs, lora_a_weights, + lora_indices_tensor, seq_len_tensor, + output_tensor, scaling) def sgmv_expand(inputs: torch.Tensor, @@ -111,12 +107,6 @@ def sgmv_expand_slice(inputs: torch.Tensor, slice_offset: int, slice_size: int, add_inputs: bool = False): - return torch.ops._C.sgmv_expand( - inputs, - lora_b_weights, - lora_indices_tensor, - seq_len_tensor, - output_tensor, - slice_offset, - slice_size - ) + return torch.ops._C.sgmv_expand(inputs, lora_b_weights, + lora_indices_tensor, seq_len_tensor, + output_tensor, slice_offset, slice_size) diff --git a/vllm_ascend/meta_registration.py b/vllm_ascend/meta_registration.py index f292e614239..47c775887de 100644 --- a/vllm_ascend/meta_registration.py +++ b/vllm_ascend/meta_registration.py @@ -80,23 +80,18 @@ def get_masked_input_and_mask_meta(input: torch.Tensor, return masked_input, mask -def bgmv_expand_meta(x: torch.Tensor, - weight: torch.Tensor, - indices: torch.Tensor, - y: torch.Tensor, - slice_offset: int, - slice_size: int): + +def bgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor, + indices: torch.Tensor, y: torch.Tensor, slice_offset: int, + slice_size: int): y_out = torch.empty_like(y) return y_out -def sgmv_expand_meta(x: torch.Tensor, - weight: torch.Tensor, - lora_indices: torch.Tensor, - seq_len: torch.Tensor, - y: torch.Tensor, - slice_offset: int, - slice_size: int): + +def sgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor, + lora_indices: torch.Tensor, seq_len: torch.Tensor, + y: torch.Tensor, slice_offset: int, slice_size: int): y_out = torch.empty_like(y) return y_out