Commit 9f0e69b

[CI/Build] Fix mypy errors (#6968)
1 parent f230cc2 commit 9f0e69b

File tree: 2 files changed, +4 -6 lines changed

vllm/_custom_ops.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 import contextlib
 import functools
-from typing import List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Union
 
 import torch
 
@@ -336,7 +336,7 @@ def scaled_fp8_quant(
     """
     # This code assumes batch_dim and num_tokens are flattened
     assert (input.ndim == 2)
-    shape = input.shape
+    shape: Union[Tuple[int, int], torch.Size] = input.shape
     if num_token_padding:
         shape = (max(num_token_padding, input.shape[0]), shape[1])
     output = torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn)

vllm/multimodal/base.py

Lines changed: 2 additions & 4 deletions
@@ -53,9 +53,7 @@ class MultiModalInputs(_MultiModalInputsBase):
     """
 
     @staticmethod
-    def _try_concat(
-        tensors: List[NestedTensors],
-    ) -> Union[GenericSequence[NestedTensors], NestedTensors]:
+    def _try_concat(tensors: List[NestedTensors]) -> BatchedTensors:
         """
         If each input tensor in the batch has the same shape, return a single
         batched tensor; otherwise, return a list of :class:`NestedTensors` with
@@ -105,7 +103,7 @@ def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs:
         return {
             k: MultiModalInputs._try_concat(item_list)
             for k, item_list in item_lists.items()
-        }  # type: ignore
+        }
 
     @staticmethod
     def as_kwargs(
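Collapsing `_try_concat`'s return annotation into the single `BatchedTensors` alias also lets the dict comprehension in `batch` type-check on its own, which is why the trailing `# type: ignore` can be dropped. A minimal sketch of that pattern follows; the alias definitions here are assumptions for illustration, and the real `NestedTensors`/`BatchedTensors` aliases in vllm/multimodal/base.py may differ.

from typing import Dict, List, Sequence, Union

import torch

# Hypothetical aliases standing in for the ones defined in vllm/multimodal/base.py.
NestedTensors = Union[List[torch.Tensor], torch.Tensor]
BatchedTensors = Union[Sequence[NestedTensors], torch.Tensor]


def try_concat(tensors: List[torch.Tensor]) -> BatchedTensors:
    # Stack only when every tensor in the batch has the same shape;
    # otherwise keep the ragged inputs as a plain sequence.
    if len({tuple(t.shape) for t in tensors}) == 1:
        return torch.stack(tensors)
    return list(tensors)


def batch(item_lists: Dict[str, List[torch.Tensor]]) -> Dict[str, BatchedTensors]:
    # The value type is the single alias, so no "# type: ignore" is needed.
    return {k: try_concat(v) for k, v in item_lists.items()}


out = batch({"pixel_values": [torch.zeros(3, 4), torch.zeros(3, 4)]})
print(type(out["pixel_values"]))  # <class 'torch.Tensor'>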
