Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import contextlib
import gc
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, TypeVar

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
LlavaConfig, LlavaForConditionalGeneration)
from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
AutoProcessor, AutoTokenizer, BatchEncoding)

from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
@@ -144,16 +145,12 @@ def example_long_prompts() -> List[str]:
"float": torch.float,
}

AutoModelForCausalLM.register(LlavaConfig, LlavaForConditionalGeneration)

_EMBEDDING_MODELS = [
"intfloat/e5-mistral-7b-instruct",
]
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)


class HfRunner:

def wrap_device(self, input: any):
def wrap_device(self, input: _T) -> _T:
if not is_cpu():
return input.to("cuda")
else:
@@ -163,13 +160,16 @@ def __init__(
self,
model_name: str,
dtype: str = "half",
*,
is_embedding_model: bool = False,
is_vision_model: bool = False,
) -> None:
assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]

self.model_name = model_name

if model_name in _EMBEDDING_MODELS:
if is_embedding_model:
# Lazy init required for AMD CI
from sentence_transformers import SentenceTransformer
self.model = self.wrap_device(
@@ -178,8 +178,13 @@ def __init__(
device="cpu",
).to(dtype=torch_dtype))
else:
if is_vision_model:
auto_cls = AutoModelForVision2Seq
else:
auto_cls = AutoModelForCausalLM

self.model = self.wrap_device(
AutoModelForCausalLM.from_pretrained(
auto_cls.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_embedding.py
Original file line number Diff line number Diff line change
@@ -28,7 +28,7 @@ def test_models(
model: str,
dtype: str,
) -> None:
hf_model = hf_runner(model, dtype=dtype)
hf_model = hf_runner(model, dtype=dtype, is_embedding_model=True)
hf_outputs = hf_model.encode(example_prompts)
del hf_model

Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_llava.py
Original file line number Diff line number Diff line change
@@ -94,7 +94,7 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
"""
model_id, vision_language_config = model_and_config

hf_model = hf_runner(model_id, dtype=dtype)
hf_model = hf_runner(model_id, dtype=dtype, is_vision_model=True)
hf_outputs = hf_model.generate_greedy(hf_image_prompts,
max_tokens,
images=hf_images)
Expand Down