Commit 61bae75

fix lint
Signed-off-by: realliujiaxu <[email protected]>
1 parent e285eb9 commit 61bae75

1 file changed
vllm_ascend/ops/linear.py

Lines changed: 24 additions & 7 deletions
@@ -15,7 +15,7 @@
 limitations under the License.
 """
 
-from typing import Optional, Union
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
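
The `Tuple` import is added so the new return-type annotations later in the diff can be spelled with `typing.Tuple[...]`. A standalone sketch of the same idiom (the function below is illustrative, not taken from this file):

    from typing import Optional, Tuple

    def split_qualified_name(name: str) -> Tuple[Optional[str], str]:
        # Returns (prefix, basename); the prefix is None when there is no dot.
        head, sep, tail = name.rpartition(".")
        return (head if sep else None), tail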
@@ -69,6 +69,9 @@ def after_create_weights_hook(self):
         self.return_bias = self.layer.return_bias
         self.quant_method = self.layer.quant_method
 
+    def apply(self, input_):
+        raise NotImplementedError
+
 
 class CustomColumnParallel(CustomTensorParallelBase):
 
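
Declaring `apply` on the base class and raising `NotImplementedError` gives call sites typed against the base class a method the checker can resolve, while still forcing subclasses to provide the real implementation. A minimal sketch of the pattern, with illustrative class names that are not from the repository:

    import torch

    class _TensorParallelBase:
        def apply(self, input_: torch.Tensor) -> torch.Tensor:
            # Declared here so code holding a base-class reference can call
            # .apply() without the type checker reporting a missing attribute;
            # subclasses override it with a real implementation.
            raise NotImplementedError

    class _DoublingParallel(_TensorParallelBase):
        def apply(self, input_: torch.Tensor) -> torch.Tensor:
            return input_ * 2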

@@ -188,18 +191,25 @@ def apply(
         return output, output_bias
 
 
-def get_custom_tp_group_column(disable_tp, prefix, layer):
+def get_custom_tp_group_column(
+    disable_tp, prefix, layer
+) -> Tuple[
+        Optional[Union[MLPCustomColumnParallel, DenseOptimMergedColumnParallel,
+                       DenseOptimQKVParallelLinear]], int, int]:
     if disable_tp:
         return None, 0, 1
 
+    custom_tp_group: Optional[Union[
+        MLPCustomColumnParallel,
+        DenseOptimMergedColumnParallel,
+        DenseOptimQKVParallelLinear,
+    ]] = None
     if "gate_up_proj" in prefix and mlp_tp_enable():
         custom_tp_group = MLPCustomColumnParallel(layer)
     elif "gate_up_proj" in prefix and dense_optim_enable():
         custom_tp_group = DenseOptimMergedColumnParallel(layer)
     elif dense_optim_enable():
         custom_tp_group = DenseOptimQKVParallelLinear(layer, prefix)
-    else:
-        custom_tp_group = None
 
     if custom_tp_group is not None:
         return custom_tp_group, custom_tp_group.tp_rank, custom_tp_group.tp_size
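
Annotating `custom_tp_group` as `Optional[Union[...]]` up front, rather than keeping an `else: custom_tp_group = None` branch, gives the variable one declared type that every branch assignment (and the no-match case) satisfies; this is the usual way to keep mypy-style checkers from inferring a too-narrow type from the first assignment. A self-contained sketch of the idea, using placeholder classes rather than the real parallel-group types:

    from typing import Optional, Tuple, Union

    class GroupA:
        tp_rank, tp_size = 0, 2

    class GroupB:
        tp_rank, tp_size = 1, 4

    def pick_group(prefix: str) -> Tuple[Optional[Union[GroupA, GroupB]], int, int]:
        # One up-front annotation covers all branches, including "no match".
        group: Optional[Union[GroupA, GroupB]] = None
        if "gate_up" in prefix:
            group = GroupA()
        elif "qkv" in prefix:
            group = GroupB()
        if group is not None:
            return group, group.tp_rank, group.tp_size
        return None, 0, 1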
@@ -329,6 +339,7 @@ def apply(
                 self.hcomm_info,
                 bias=bias_)
         else:
+            assert self.quant_method is not None
             output = self.quant_method.apply(self.layer,
                                              input_parallel,
                                              bias=bias_)
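
The added `assert self.quant_method is not None` narrows the attribute from an optional type to a concrete one before the `.apply(...)` call, which silences the checker's "possibly None" complaint and documents the runtime expectation. A minimal sketch of that narrowing, with hypothetical names:

    from typing import Optional

    class QuantMethod:
        def apply(self, x: float) -> float:
            return x

    class Layer:
        def __init__(self, quant_method: Optional[QuantMethod]) -> None:
            self.quant_method = quant_method

        def forward(self, x: float) -> float:
            # The assert narrows Optional[QuantMethod] to QuantMethod for the
            # type checker and fails fast if the method was never set.
            assert self.quant_method is not None
            return self.quant_method.apply(x)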
@@ -407,10 +418,18 @@ def after_create_weights_hook(self):
         self.reduce_results = self.layer.reduce_results
 
 
-def get_custom_tp_group_row(disable_tp, prefix, layer):
+def get_custom_tp_group_row(
+    disable_tp, prefix, layer
+) -> Tuple[Optional[Union[MLPCustomRowParallel, OProjCustomRowParallel,
+                          MatmulAllreduceCustomRowParallel,
+                          DenseOptimCustomRowParallel]], int, int]:
     if disable_tp:
         return None, 0, 1
 
+    custom_tp_group: Optional[Union[MLPCustomRowParallel,
+                                    OProjCustomRowParallel,
+                                    MatmulAllreduceCustomRowParallel,
+                                    DenseOptimCustomRowParallel]] = None
     if prefix.find("down_proj") != -1 and mlp_tp_enable():
         custom_tp_group = MLPCustomRowParallel(layer)
     elif prefix.find("o_proj") != -1 and oproj_tp_enable():
@@ -419,8 +438,6 @@ def get_custom_tp_group_row(disable_tp, prefix, layer):
         custom_tp_group = MatmulAllreduceCustomRowParallel(layer)
     elif dense_optim_enable():
         custom_tp_group = DenseOptimCustomRowParallel(layer, prefix)
-    else:
-        custom_tp_group = None
 
     if custom_tp_group is not None:
         return custom_tp_group, custom_tp_group.tp_rank, custom_tp_group.tp_size
