2 files changed, +4 −6 lines changed
@@ -1,6 +1,6 @@
 import contextlib
 import functools
-from typing import List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Union
 
 import torch
 
@@ -336,7 +336,7 @@ def scaled_fp8_quant(
     """
     # This code assumes batch_dim and num_tokens are flattened
     assert (input.ndim == 2)
-    shape = input.shape
+    shape: Union[Tuple[int, int], torch.Size] = input.shape
     if num_token_padding:
         shape = (max(num_token_padding, input.shape[0]), shape[1])
     output = torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn)
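Why the annotation: `input.shape` is a `torch.Size`, but the padding branch reassigns `shape` to a plain `(rows, cols)` tuple, which mypy rejects without a `Union`. A minimal, self-contained sketch of the same pattern (`pad_rows` is a hypothetical stand-in, not the real `scaled_fp8_quant`):

from typing import Optional, Tuple, Union

import torch


def pad_rows(input: torch.Tensor,
             num_token_padding: Optional[int] = None) -> torch.Tensor:
    assert input.ndim == 2
    # `input.shape` is a torch.Size; the padding branch reassigns `shape`
    # to a plain (rows, cols) tuple, hence the Union annotation.
    shape: Union[Tuple[int, int], torch.Size] = input.shape
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    # Generic dtype here; the real code allocates torch.float8_e4m3fn.
    return torch.empty(shape, device=input.device, dtype=input.dtype)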
@@ -53,9 +53,7 @@ class MultiModalInputs(_MultiModalInputsBase):
     """
 
     @staticmethod
-    def _try_concat(
-        tensors: List[NestedTensors],
-    ) -> Union[GenericSequence[NestedTensors], NestedTensors]:
+    def _try_concat(tensors: List[NestedTensors]) -> BatchedTensors:
         """
         If each input tensor in the batch has the same shape, return a single
         batched tensor; otherwise, return a list of :class:`NestedTensors` with
@@ -105,7 +103,7 @@ def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs:
         return {
             k: MultiModalInputs._try_concat(item_list)
             for k, item_list in item_lists.items()
-        }  # type: ignore
+        }
 
     @staticmethod
     def as_kwargs(
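For context, a hedged sketch of the behavior the (truncated) docstring describes; the alias below is a simplified stand-in for the real `BatchedTensors` type in vLLM's multimodal code, not its exact definition:

from typing import List, Union

import torch

# Simplified stand-in; the real BatchedTensors alias is more general
# (it also covers nested lists of tensors).
BatchedTensors = Union[List[torch.Tensor], torch.Tensor]


def _try_concat(tensors: List[torch.Tensor]) -> BatchedTensors:
    # All inputs share one shape: stack them into a single batched tensor.
    if len({tuple(t.shape) for t in tensors}) == 1:
        return torch.stack(tensors)
    # Ragged shapes: fall back to a plain list.
    return list(tensors)

Because `_try_concat` now declares a concrete `BatchedTensors` return type, the dict comprehension in `batch` presumably type-checks without the trailing `# type: ignore`.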