
Commit

test
Signed-off-by: Cody Yu <[email protected]>
comaniac committed Feb 5, 2025
1 parent 9568f8c commit ec4e3ce
Showing 7 changed files with 89 additions and 16 deletions.
2 changes: 1 addition & 1 deletion ci/docker/llm.build.Dockerfile
@@ -15,6 +15,6 @@ RUN <<EOF
set -ex

# TODO(comaniac): add other test dependencies here.
pip install -c python/requirements_compiled.txt pytest aiohttp pillow
pip install -c python/requirements_compiled.txt pytest aiohttp pillow "vllm==0.7.1"

EOF
4 changes: 2 additions & 2 deletions python/ray/llm/_internal/batch/stages/vllm_engine_stage.py
@@ -334,7 +334,7 @@ def __init__(
# Create an LLM engine.
self.llm = vLLMEngineWrapper(
model=self.model,
idx_in_batch_column=self.idx_in_batch_column,
idx_in_batch_column=self.IDX_IN_BATCH_COLUMN,
disable_log_stats=False,
max_pending_requests=self.max_pending_requests,
runtime_env=self.runtime_env,
@@ -423,7 +423,7 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]]
yield {
**output,
"request_id": request.request_id,
self.idx_in_batch_column: request.idx_in_batch,
self.IDX_IN_BATCH_COLUMN: request.idx_in_batch,
"batch_uuid": batch_uuid.hex,
"time_taken_llm": time_taken,
"params": str(request.params),
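For context, the two changed lines above switch from a lowercase attribute to the uppercase class-level constant. A minimal sketch of that pattern, with illustrative names (EngineStageUDF and output_row are not the real Ray API); the constant's value "__idx_in_batch" is taken from the assertions in the test file below:

class EngineStageUDF:
    # Class-level constant; every instance resolves self.IDX_IN_BATCH_COLUMN
    # to this shared column name.
    IDX_IN_BATCH_COLUMN = "__idx_in_batch"

    def output_row(self, idx_in_batch: int) -> dict:
        # Mirrors the fixed lines above: the column name comes from the class
        # constant rather than a lowercase instance attribute.
        return {self.IDX_IN_BATCH_COLUMN: idx_in_batch}


row = EngineStageUDF().output_row(0)
assert row == {"__idx_in_batch": 0}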
99 changes: 86 additions & 13 deletions python/ray/llm/tests/batch/stages/test_vllm_engine_stage.py
@@ -11,6 +11,79 @@
from ray.llm._internal.batch.stages.vllm_engine_stage import vLLMEngineWrapper


@pytest.fixture(scope="module")
def dummy_model_ckpt():
"""
This fixture creates a dummy model checkpoint for testing.
It uses the facebook/opt-125m model config and tokenizer to generate a dummy model.
The purpose of this is to avoid downloading the model from HuggingFace hub during
testing, which is flaky because of the rate limit and HF hub downtime.
"""
import os
import tempfile
import glob
import zipfile
from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM

model_config = {
"activation_dropout": 0.0,
"activation_function": "relu",
"architectures": ["OPTForCausalLM"],
"attention_dropout": 0.0,
"bos_token_id": 2,
"do_layer_norm_before": True,
"dropout": 0.1,
"eos_token_id": 2,
"ffn_dim": 3072,
"hidden_size": 768,
"init_std": 0.02,
"layerdrop": 0.0,
"max_position_embeddings": 2048,
"model_type": "opt",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 1,
"prefix": "</s>",
"torch_dtype": "float16",
"use_cache": True,
"vocab_size": 50272,
"word_embed_proj_dim": 768,
}

# Create a dummy model checkpoint.
config = OPTConfig(**model_config)
model = OPTForCausalLM(config)

# Load the tokenizer.
tokenizer_zip_path = os.path.join(
os.path.dirname(__file__), "../test_files/opt-tokenizer"
)
with tempfile.TemporaryDirectory() as tokenizer_path:
# Collect all split parts (opt-tokenizer.zip.*) of the tokenizer archive.
tokenizer_zip_paths = list(
glob.glob(os.path.join(tokenizer_zip_path, "opt-tokenizer.zip.*"))
)
full_zip_path = os.path.join(tokenizer_path, "opt-tokenizer_merged.zip")
with open(full_zip_path, "wb") as outfile:
for zip_path in sorted(tokenizer_zip_paths):
with open(zip_path, "rb") as infile:
outfile.write(infile.read())

with zipfile.ZipFile(full_zip_path, "r") as zip_ref:
zip_ref.extractall(tokenizer_path)

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

# Create a temporary directory and save the model and tokenizer.
with tempfile.TemporaryDirectory() as checkpoint_dir:

config.save_pretrained(checkpoint_dir)
model.save_pretrained(checkpoint_dir)
tokenizer.save_pretrained(checkpoint_dir)

yield os.path.abspath(checkpoint_dir)


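The fixture restores a tokenizer archive that is checked into the repo as byte-split parts (the binary files listed at the end of this diff) by concatenating the sorted parts back into one zip. For reference, a hypothetical helper showing how such parts could be produced; the tool actually used to split the checked-in archive is not shown in this commit:

def split_file(path: str, chunk_bytes: int = 5 * 1024 * 1024) -> list:
    """Split `path` into path.000, path.001, ... so that concatenating the
    parts in sorted order (as the fixture above does) restores the original bytes."""
    parts = []
    with open(path, "rb") as src:
        index = 0
        while True:
            data = src.read(chunk_bytes)
            if not data:
                break
            part_path = f"{path}.{index:03d}"
            with open(part_path, "wb") as dst:
                dst.write(data)
            parts.append(part_path)
            index += 1
    return parts
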
@pytest.fixture
def mock_vllm_wrapper():
with patch(
@@ -50,10 +123,10 @@ async def mock_generate(row):
yield mock_wrapper


def test_vllm_engine_stage_post_init():
def test_vllm_engine_stage_post_init(dummy_model_ckpt):
stage = vLLMEngineStage(
fn_constructor_kwargs=dict(
model="facebook/opt-125m",
model=dummy_model_ckpt,
engine_kwargs=dict(
tensor_parallel_size=4,
pipeline_parallel_size=2,
@@ -71,7 +144,7 @@ def test_vllm_engine_stage_post_init():
)

assert stage.fn_constructor_kwargs == {
"model": "facebook/opt-125m",
"model": dummy_model_ckpt,
"task_type": "generate",
"max_pending_requests": 10,
"engine_kwargs": {
@@ -100,11 +173,11 @@ def test_vllm_engine_stage_post_init():


@pytest.mark.asyncio
async def test_vllm_engine_udf_basic(mock_vllm_wrapper):
async def test_vllm_engine_udf_basic(mock_vllm_wrapper, dummy_model_ckpt):
# Create UDF instance - it will use the mocked wrapper
udf = vLLMEngineStageUDF(
data_column="__data",
model="facebook/opt-125m",
model=dummy_model_ckpt,
task_type="generate",
engine_kwargs={
# Test that this should be overridden by the stage.
@@ -115,7 +188,7 @@ async def test_vllm_engine_udf_basic(mock_vllm_wrapper):
},
)

assert udf.model == "facebook/opt-125m"
assert udf.model == dummy_model_ckpt
assert udf.task_type == "generate"
assert udf.engine_kwargs["task"] == "generate"
assert udf.engine_kwargs["max_num_seqs"] == 100
@@ -143,7 +216,7 @@ async def test_vllm_engine_udf_basic(mock_vllm_wrapper):

# Verify the wrapper was constructed with correct arguments
mock_vllm_wrapper.assert_called_once_with(
model="facebook/opt-125m",
model=dummy_model_ckpt,
idx_in_batch_column="__idx_in_batch",
disable_log_stats=False,
max_pending_requests=111,
@@ -161,7 +234,7 @@ async def test_vllm_engine_udf_basic(mock_vllm_wrapper):


@pytest.mark.asyncio
async def test_vllm_wrapper_pending_queue():
async def test_vllm_wrapper_pending_queue(dummy_model_ckpt):
from vllm.outputs import RequestOutput, CompletionOutput

max_pending_requests = 2
@@ -210,7 +283,7 @@ async def mock_generate(request):

# Create wrapper with max 2 pending requests
wrapper = vLLMEngineWrapper(
model="facebook/opt-125m",
model=dummy_model_ckpt,
idx_in_batch_column="__idx_in_batch",
disable_log_stats=True,
max_pending_requests=max_pending_requests,
@@ -231,14 +304,14 @@ async def mock_generate(request):

@pytest.mark.asyncio
@pytest.mark.parametrize("version", ["v0", "v1"])
async def test_vllm_wrapper_generate(version):
async def test_vllm_wrapper_generate(version, dummy_model_ckpt):
if version == "v1":
runtime_env = {"env": {"VLLM_USE_V1": "1"}}
else:
runtime_env = {}

wrapper = vLLMEngineWrapper(
model="facebook/opt-125m",
model=dummy_model_ckpt,
idx_in_batch_column="__idx_in_batch",
disable_log_stats=True,
max_pending_requests=10,
@@ -281,14 +354,14 @@ async def test_vllm_wrapper_generate(version):

@pytest.mark.asyncio
@pytest.mark.parametrize("version", ["v0", "v1"])
async def test_vllm_wrapper_embed(version):
async def test_vllm_wrapper_embed(version, dummy_model_ckpt):
if version == "v1":
runtime_env = {"env": {"VLLM_USE_V1": "1"}}
else:
runtime_env = {}

wrapper = vLLMEngineWrapper(
model="facebook/opt-125m",
model=dummy_model_ckpt,
idx_in_batch_column="__idx_in_batch",
disable_log_stats=True,
max_pending_requests=10,
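With the dummy checkpoint fixture, this test module can run without network access. A hypothetical way to invoke just this file, assuming it is run from the repo root and that pytest-asyncio (needed for the @pytest.mark.asyncio tests) is provided via the compiled requirements:

import pytest

# Run only the vLLM engine stage tests; the path assumes the Ray repo root as cwd.
pytest.main(["-v", "python/ray/llm/tests/batch/stages/test_vllm_engine_stage.py"])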
4 binary files (the split opt-tokenizer archive parts) not shown.
