From 488e2b837a6ec5ee95b98599e90fc44fdc03b70f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 12 Jul 2024 20:18:02 -0700 Subject: [PATCH 01/10] update gpt2 --- vllm/model_executor/models/gpt2.py | 33 ++++++++++++++++-------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index be19f4ba8c71..93045b7cccc1 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -291,19 +291,22 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue if not name.startswith("transformer."): name = "transformer." + name - try: - param = params_dict[name] - # The HF's GPT-2 implementation uses Conv1D instead of Linear. - # Because of this, we need to transpose the weights. - # Note(zhuohan): the logic below might break quantized models. - for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: - if conv1d_weight_name not in name: - continue - if not name.endswith(".weight"): - continue - loaded_weight = loaded_weight.t() - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - except KeyError: + + if name not in params_dict: + # in pipeline parallelism, we may have layers that are not + # present on this rank continue + + param = params_dict[name] + # The HF's GPT-2 implementation uses Conv1D instead of Linear. + # Because of this, we need to transpose the weights. + # Note(zhuohan): the logic below might break quantized models. + for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: + if conv1d_weight_name not in name: + continue + if not name.endswith(".weight"): + continue + loaded_weight = loaded_weight.t() + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From e7a2842e9f3bea638d88f04d121ae8813a1e14b9 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 12 Jul 2024 20:19:51 -0700 Subject: [PATCH 02/10] update llama --- vllm/model_executor/models/llama.py | 33 +++++++++++++++++------------ 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 77edcd7402db..29596efa1233 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -455,12 +455,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - try: - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - except KeyError: - pass + + if name not in params_dict: + # in pipeline parallelism, we may have layers that are not + # present on this rank + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break else: # Skip loading extra bias for GPTQ models. 
@@ -479,13 +483,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue else: name = remapped_kv_scale_name - try: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - except KeyError: - pass + + if name not in params_dict: + # in pipeline parallelism, we may have layers that are not + # present on this rank + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). Thus, handled exceptions should From 8b725b5ae9fb5cc27dd01f15a82c4026ee17ae1f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 12 Jul 2024 20:35:30 -0700 Subject: [PATCH 03/10] remove duplicate code --- vllm/model_executor/models/gpt2.py | 15 +++------------ vllm/model_executor/models/llama.py | 22 ++++++---------------- vllm/utils.py | 17 +++++++++++++++++ 3 files changed, 26 insertions(+), 28 deletions(-) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 93045b7cccc1..7519455c674b 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -27,7 +27,6 @@ from vllm.config import CacheConfig from vllm.distributed.parallel_state import ( get_pp_group, get_tensor_model_parallel_world_size) -from vllm.distributed.utils import get_pp_indices from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -41,6 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.utils import make_layers class GPT2Attention(nn.Module): @@ -183,18 +183,9 @@ def __init__( self.embed_dim = config.hidden_size self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) - self.start_layer, self.end_layer = get_pp_indices( + self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, - get_pp_group().rank_in_group, - get_pp_group().world_size) - self.h = nn.ModuleList( - [nn.Identity() for _ in range(self.start_layer)] + [ - GPT2Block(config, cache_config, quant_config) - for _ in range(self.start_layer, self.end_layer) - ] + [ - nn.Identity() - for _ in range(self.end_layer, config.num_hidden_layers) - ]) + lambda: GPT2Block(config, cache_config, quant_config)) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) def forward( diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 29596efa1233..fa47310f2b1a 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -29,8 +29,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import (get_pp_group, get_pp_indices, - get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -48,7 +47,7 @@ default_weight_loader, kv_cache_scales_loader) from vllm.model_executor.sampling_metadata 
import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput -from vllm.utils import is_hip, print_warning_once +from vllm.utils import is_hip, make_layers, print_warning_once from .interfaces import SupportsLoRA @@ -262,20 +261,11 @@ def __init__( config.hidden_size, org_num_embeddings=config.vocab_size, ) - self.start_layer, self.end_layer = get_pp_indices( + self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - get_pp_group().rank_in_group, - get_pp_group().world_size) - self.layers = nn.ModuleList( - [nn.Identity() for _ in range(self.start_layer)] + [ - LlamaDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config) - for _ in range(self.start_layer, self.end_layer) - ] + [ - nn.Identity() - for _ in range(self.end_layer, config.num_hidden_layers) - ]) + lambda: LlamaDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config)) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: diff --git a/vllm/utils.py b/vllm/utils.py index 8be1528230b5..776433375f8c 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -939,3 +939,20 @@ def parse_args(self, args=None, namespace=None): processed_args.append(arg) return super().parse_args(processed_args, namespace) + + +def make_layers( + num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module] +) -> Tuple[int, int, torch.nn.ModuleList]: + """Make a list of layers with the given layer function, taking + pipeline parallelism into account. + """ + from vllm.distributed.utils import get_pp_group, get_pp_indices + start_layer, end_layer = get_pp_indices(num_hidden_layers, + get_pp_group().rank_in_group, + get_pp_group().world_size) + modules = torch.nn.ModuleList( + [torch.nn.Identity() for _ in range(start_layer)] + + [layer_fn() for _ in range(start_layer, end_layer)] + + [torch.nn.Identity() for _ in range(end_layer, num_hidden_layers)]) + return start_layer, end_layer, modules From 1945e619a13f50553187492046978dc71a3f608c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 12 Jul 2024 21:28:28 -0700 Subject: [PATCH 04/10] fix import --- vllm/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/utils.py b/vllm/utils.py index 776433375f8c..6ce86eb7a579 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -947,7 +947,8 @@ def make_layers( """Make a list of layers with the given layer function, taking pipeline parallelism into account. 
""" - from vllm.distributed.utils import get_pp_group, get_pp_indices + from vllm.distributed.parallel_state import get_pp_group + from vllm.distributed.utils import get_pp_indices start_layer, end_layer = get_pp_indices(num_hidden_layers, get_pp_group().rank_in_group, get_pp_group().world_size) From 347399e2e62693f850fffb130565caccedcab7df Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 13 Jul 2024 10:35:56 -0700 Subject: [PATCH 05/10] move to models/ --- vllm/model_executor/models/gpt2.py | 3 ++- vllm/model_executor/models/llama.py | 3 ++- vllm/model_executor/models/utils.py | 20 ++++++++++++++++++++ vllm/utils.py | 18 ------------------ 4 files changed, 24 insertions(+), 20 deletions(-) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 7519455c674b..9d33d35c379d 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -40,7 +40,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput -from vllm.utils import make_layers + +from .utils import make_layers class GPT2Attention(nn.Module): diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index fa47310f2b1a..1a6d433823da 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -47,9 +47,10 @@ default_weight_loader, kv_cache_scales_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput -from vllm.utils import is_hip, make_layers, print_warning_once +from vllm.utils import is_hip, print_warning_once from .interfaces import SupportsLoRA +from .utils import make_layers class LlamaMLP(nn.Module): diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index ef2562b073e6..5aafa432abb7 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,3 +1,5 @@ +from typing import Callable, Tuple + import torch from vllm.multimodal import BatchedTensors @@ -39,3 +41,21 @@ def merge_vision_embeddings(input_ids: torch.Tensor, inputs_embeds[mask] = torch.cat(vision_embeddings) return inputs_embeds + + +def make_layers( + num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module] +) -> Tuple[int, int, torch.nn.ModuleList]: + """Make a list of layers with the given layer function, taking + pipeline parallelism into account. + """ + from vllm.distributed.parallel_state import get_pp_group + from vllm.distributed.utils import get_pp_indices + start_layer, end_layer = get_pp_indices(num_hidden_layers, + get_pp_group().rank_in_group, + get_pp_group().world_size) + modules = torch.nn.ModuleList( + [torch.nn.Identity() for _ in range(start_layer)] + + [layer_fn() for _ in range(start_layer, end_layer)] + + [torch.nn.Identity() for _ in range(end_layer, num_hidden_layers)]) + return start_layer, end_layer, modules diff --git a/vllm/utils.py b/vllm/utils.py index 6ce86eb7a579..8be1528230b5 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -939,21 +939,3 @@ def parse_args(self, args=None, namespace=None): processed_args.append(arg) return super().parse_args(processed_args, namespace) - - -def make_layers( - num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module] -) -> Tuple[int, int, torch.nn.ModuleList]: - """Make a list of layers with the given layer function, taking - pipeline parallelism into account. 
- """ - from vllm.distributed.parallel_state import get_pp_group - from vllm.distributed.utils import get_pp_indices - start_layer, end_layer = get_pp_indices(num_hidden_layers, - get_pp_group().rank_in_group, - get_pp_group().world_size) - modules = torch.nn.ModuleList( - [torch.nn.Identity() for _ in range(start_layer)] + - [layer_fn() for _ in range(start_layer, end_layer)] + - [torch.nn.Identity() for _ in range(end_layer, num_hidden_layers)]) - return start_layer, end_layer, modules From 61fa24286bd7de5468775f8b2de9360f00991cf5 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 13 Jul 2024 10:56:38 -0700 Subject: [PATCH 06/10] check names --- vllm/model_executor/models/gpt2.py | 6 ++--- vllm/model_executor/models/llama.py | 10 +++------ vllm/model_executor/models/utils.py | 34 ++++++++++++++++++++++++++--- 3 files changed, 36 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 9d33d35c379d..d309a2b27f5d 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -41,7 +41,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput -from .utils import make_layers +from .utils import is_pp_missing_parameter, make_layers class GPT2Attention(nn.Module): @@ -284,9 +284,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if not name.startswith("transformer."): name = "transformer." + name - if name not in params_dict: - # in pipeline parallelism, we may have layers that are not - # present on this rank + if is_pp_missing_parameter(name, self): continue param = params_dict[name] diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 1a6d433823da..a777d1fbfa80 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -50,7 +50,7 @@ from vllm.utils import is_hip, print_warning_once from .interfaces import SupportsLoRA -from .utils import make_layers +from .utils import is_pp_missing_parameter, make_layers class LlamaMLP(nn.Module): @@ -447,9 +447,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(".bias") and name not in params_dict: continue - if name not in params_dict: - # in pipeline parallelism, we may have layers that are not - # present on this rank + if is_pp_missing_parameter(name, self): continue param = params_dict[name] @@ -475,9 +473,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: name = remapped_kv_scale_name - if name not in params_dict: - # in pipeline parallelism, we may have layers that are not - # present on this rank + if is_pp_missing_parameter(name, self): continue param = params_dict[name] diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 5aafa432abb7..a60e8b376af5 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,4 +1,5 @@ -from typing import Callable, Tuple +from functools import lru_cache +from typing import Callable, List, Tuple import torch @@ -43,6 +44,15 @@ def merge_vision_embeddings(input_ids: torch.Tensor, return inputs_embeds +class PPMissingLayer(torch.nn.Identity): + """ + A placeholder layer for missing layers in a pipeline parallel model. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__() + + def make_layers( num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module] ) -> Tuple[int, int, torch.nn.ModuleList]: @@ -55,7 +65,25 @@ def make_layers( get_pp_group().rank_in_group, get_pp_group().world_size) modules = torch.nn.ModuleList( - [torch.nn.Identity() for _ in range(start_layer)] + + [PPMissingLayer() for _ in range(start_layer)] + [layer_fn() for _ in range(start_layer, end_layer)] + - [torch.nn.Identity() for _ in range(end_layer, num_hidden_layers)]) + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]) return start_layer, end_layer, modules + + +@lru_cache +def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]: + """Get the names of the missing layers in a pipeline parallel model.""" + missing_layer_names = [] + for name, module in model.named_modules(): + if isinstance(module, PPMissingLayer): + missing_layer_names.append(name) + return missing_layer_names + + +def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: + """Check if a parameter is missing in a pipeline parallel model.""" + for missing_layer_name in get_pp_missing_layer_names(model): + if name.startswith(missing_layer_name): + return True + return False From 78d7d4b9b566777f0b8ef79d00197eb7b86a2749 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 13 Jul 2024 22:58:05 -0700 Subject: [PATCH 07/10] unify tests --- .buildkite/test-pipeline.yaml | 4 +--- tests/basic_correctness/test_basic_correctness.py | 11 ++++++++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index c8f53224b1dc..54fd76b1e81c 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -46,9 +46,7 @@ steps: fast_check: true commands: - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl - - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py - - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py - - VLLM_ATTENTION_BACKEND=FLASHINFER pytest -v -s basic_correctness/test_basic_correctness.py + - pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index d3e74a4f834a..876fd37c1c19 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -28,10 +28,8 @@ def test_vllm_gc_ed(): assert weak_llm() is None -@pytest.mark.skipif(is_hip() - and os.getenv("VLLM_ATTENTION_BACKEND") == "FLASHINFER", - reason="Flashinfer does not support ROCm/HIP.") @pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("backend", ["XFORMERS", "FLASH_ATTN", "FLASHINFER"]) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("enforce_eager", [False, True]) @@ -40,10 +38,17 @@ def test_models( vllm_runner, example_prompts, model: str, + backend: str, dtype: str, max_tokens: int, enforce_eager: bool, ) -> None: + + if backend == "FLASHINFER" and is_hip(): + 
pytest.skip("Flashinfer does not support ROCm/HIP.") + + os.environ["VLLM_ATTENTION_BACKEND"] = backend + with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) From 5d8ddd4e0485828fe545c3f355f2eb47ec0bf6f4 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 13 Jul 2024 23:00:08 -0700 Subject: [PATCH 08/10] change order --- tests/basic_correctness/test_basic_correctness.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 876fd37c1c19..ec7c2ba3e3ce 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -29,7 +29,7 @@ def test_vllm_gc_ed(): @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("backend", ["XFORMERS", "FLASH_ATTN", "FLASHINFER"]) +@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"]) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("enforce_eager", [False, True]) From d4f851f816913b5f2d9937fb029fd270ad3d6e28 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 14 Jul 2024 09:04:35 -0700 Subject: [PATCH 09/10] fix gc --- vllm/model_executor/models/utils.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index a60e8b376af5..33a0653a1564 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,5 +1,4 @@ -from functools import lru_cache -from typing import Callable, List, Tuple +from typing import Callable, List, Tuple, Dict import torch @@ -71,13 +70,22 @@ def make_layers( return start_layer, end_layer, modules -@lru_cache +# NOTE: don't use lru_cache here because it can prevent garbage collection +_model_to_pp_missing_layer_names: Dict[int, List[str]] = {} + + def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]: """Get the names of the missing layers in a pipeline parallel model.""" + model_id = id(model) + if model_id in _model_to_pp_missing_layer_names: + return _model_to_pp_missing_layer_names[model_id] + missing_layer_names = [] for name, module in model.named_modules(): if isinstance(module, PPMissingLayer): missing_layer_names.append(name) + _model_to_pp_missing_layer_names[model_id] = missing_layer_names + return missing_layer_names From 50b91a82f990295b98efc2ae9fb103e2b863fa43 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 14 Jul 2024 09:14:55 -0700 Subject: [PATCH 10/10] isort --- vllm/model_executor/models/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 33a0653a1564..a0d2a0286ff6 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Tuple, Dict +from typing import Callable, Dict, List, Tuple import torch