Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions tests/lora/test_lora_resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import pytest

from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry


class DummyLoRAResolver(LoRAResolver):
    """Minimal resolver for the tests: it recognises exactly one adapter."""

    async def resolve_lora(self, lora_name: str) -> Optional[LoRARequest]:
        """Return a request for ``"test_lora"`` and ``None`` for any other name."""
        if lora_name != "test_lora":
            return None

        return LoRARequest(
            lora_name=lora_name,
            lora_path="/dummy/path",
            lora_int_id=abs(hash(lora_name)),
        )


def test_resolver_registry_registration():
    """Test basic resolver registration functionality.

    Registers a resolver under the name ``"dummy"``, checks that the name is
    reported as supported and that the very same instance is returned by a
    lookup, then removes the entry again.
    """
    registry = LoRAResolverRegistry
    resolver = DummyLoRAResolver()

    try:
        # Register a new resolver
        registry.register_resolver("dummy", resolver)
        assert "dummy" in registry.get_supported_resolvers()

        # Get registered resolver
        retrieved_resolver = registry.get_resolver("dummy")
        assert retrieved_resolver is resolver
    finally:
        # LoRAResolverRegistry is a module-level singleton; drop the entry so
        # this test does not leak a registered resolver into other tests.
        registry.resolvers.pop("dummy", None)


def test_resolver_registry_duplicate_registration():
    """Test registering a resolver with an existing name.

    Registering a second resolver under an already-used name must replace
    the first one: a lookup afterwards returns the second instance.
    """
    registry = LoRAResolverRegistry
    resolver1 = DummyLoRAResolver()
    resolver2 = DummyLoRAResolver()

    try:
        registry.register_resolver("dummy", resolver1)
        registry.register_resolver("dummy", resolver2)

        assert registry.get_resolver("dummy") is resolver2
    finally:
        # LoRAResolverRegistry is a module-level singleton; drop the entry so
        # this test does not leak a registered resolver into other tests.
        registry.resolvers.pop("dummy", None)


def test_resolver_registry_unknown_resolver():
    """Looking up a name that was never registered must raise ``KeyError``."""
    missing_name = "unknown_resolver"

    with pytest.raises(KeyError, match="not found"):
        LoRAResolverRegistry.get_resolver(missing_name)


@pytest.mark.asyncio
async def test_dummy_resolver_resolve():
    """Exercise both outcomes of ``DummyLoRAResolver.resolve_lora``."""
    resolver = DummyLoRAResolver()

    # Known adapter name: a fully populated request comes back.
    hit = await resolver.resolve_lora("test_lora")
    assert isinstance(hit, LoRARequest)
    assert hit.lora_name == "test_lora"
    assert hit.lora_path == "/dummy/path"

    # Unknown adapter name: the resolver reports "not found" via None.
    miss = await resolver.resolve_lora("nonexistent_lora")
    assert miss is None
3 changes: 3 additions & 0 deletions vllm/entrypoints/openai/serving_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ async def _check_model(
lora.lora_name for lora in self.models.lora_requests
]:
return None
if request.model is not None and (await self.models.resolve_lora(
request.model)):
return None
if request.model in [
prompt_adapter.prompt_adapter_name
for prompt_adapter in self.models.prompt_adapter_requests
Expand Down
55 changes: 55 additions & 0 deletions vllm/entrypoints/openai/serving_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import json
import pathlib
from asyncio import Lock
from collections import defaultdict
from dataclasses import dataclass
from http import HTTPStatus
from typing import Optional, Union
Expand All @@ -15,6 +17,7 @@
UnloadLoRAAdapterRequest)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.utils import AtomicCounter

Expand Down Expand Up @@ -68,6 +71,13 @@ def __init__(
self.lora_requests: list[LoRARequest] = []
self.lora_id_counter = AtomicCounter(0)

self.lora_resolvers: list[LoRAResolver] = []
for lora_resolver_name in LoRAResolverRegistry.get_supported_resolvers(
):
self.lora_resolvers.append(
LoRAResolverRegistry.get_resolver(lora_resolver_name))
self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)

self.prompt_adapter_requests = []
if prompt_adapters is not None:
for i, prompt_adapter in enumerate(prompt_adapters, start=1):
Expand Down Expand Up @@ -234,6 +244,51 @@ async def _check_unload_lora_adapter_request(

return None

async def resolve_lora(self, lora_name: str) -> Optional[LoRARequest]:
    """Attempt to resolve a LoRA adapter using available resolvers.

    The adapter is first looked up among the already-loaded requests in
    ``self.lora_requests``. If it is not there, each resolver in
    ``self.lora_resolvers`` is asked in order; the first request that the
    engine accepts via ``add_lora`` is recorded and returned. A resolver
    whose result fails to load is logged and skipped.

    Args:
        lora_name: Name/identifier of the LoRA adapter

    Returns:
        Optional[LoRARequest]: LoRA request if found, None otherwise
    """
    # One lock per adapter name, so concurrent calls for the same name are
    # serialized and the second caller sees the first caller's result.
    async with self.lora_resolver_lock[lora_name]:
        # First check if this LoRA is already loaded
        for existing in self.lora_requests:
            if existing.lora_name == lora_name:
                return existing

        # NOTE(review): str hashes are salted per process, so this id is not
        # stable across restarts and could in principle collide with ids
        # assigned elsewhere - confirm this is acceptable.
        unique_id = abs(hash(lora_name))

        # Try to resolve using available resolvers
        for resolver in self.lora_resolvers:
            lora_request = await resolver.resolve_lora(lora_name)
            if lora_request is None:
                continue

            lora_request.lora_int_id = unique_id

            try:
                await self.engine_client.add_lora(lora_request)
            except Exception as e:
                # Only ordinary errors are treated as "this resolver
                # failed". BaseException subclasses such as
                # asyncio.CancelledError or KeyboardInterrupt must
                # propagate so cancellation/shutdown is not swallowed.
                logger.warning(
                    "Failed to load LoRA '%s' resolved by %s: %s. "
                    "Trying next resolver.", lora_name,
                    resolver.__class__.__name__, e)
                continue  # Try the next resolver

            # Successfully added, append and return
            self.lora_requests.append(lora_request)
            logger.info(
                "Resolved and loaded LoRA adapter '%s' using %s",
                lora_name, resolver.__class__.__name__)
            return lora_request

        # If no resolver could successfully resolve and load the LoRA
        logger.warning(
            "Could not resolve or load LoRA adapter '%s' with any "
            "available resolver.", lora_name)
        return None


def create_error_response(
message: str,
Expand Down
81 changes: 81 additions & 0 deletions vllm/lora/resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import AbstractSet, Dict, Optional

from vllm.logger import init_logger
from vllm.lora.request import LoRARequest

logger = init_logger(__name__)


class LoRAResolver(ABC):
    """Interface for components that can locate LoRA adapters by name.

    A concrete resolver is responsible for finding an adapter and making it
    available, for example by downloading it from S3 or some other cloud
    storage, and for describing the result as a ``LoRARequest``.
    """

    @abstractmethod
    async def resolve_lora(self, lora_name: str) -> Optional[LoRARequest]:
        """Locate (and, if necessary, fetch) the adapter called *lora_name*.

        Subclasses supply the actual lookup, e.g. pulling the adapter from
        blob storage or another source.

        Args:
            lora_name: The name or identifier of the LoRA model to resolve.

        Returns:
            Optional[LoRARequest]: The resolved LoRA model information, or
                None if the LoRA model cannot be found.
        """


@dataclass
class _LoRAResolverRegistry:
    """Name -> resolver-instance mapping backing ``LoRAResolverRegistry``."""

    # Registered resolver instances, keyed by the name they were added under.
    resolvers: Dict[str, LoRAResolver] = field(default_factory=dict)

    def get_supported_resolvers(self) -> AbstractSet[str]:
        """Return the names of all registered resolvers (a live keys view)."""
        return self.resolvers.keys()

    def register_resolver(
        self,
        resolver_name: str,
        resolver: LoRAResolver,
    ) -> None:
        """Store *resolver* under *resolver_name*.

        An existing entry with the same name is replaced; a warning is
        logged when that happens.

        Args:
            resolver_name: Name to register the resolver under.
            resolver: The LoRA resolver instance to register.
        """
        already_registered = resolver_name in self.resolvers
        if already_registered:
            logger.warning(
                "LoRA resolver %s is already registered, and will be "
                "overwritten by the new resolver instance %s.",
                resolver_name,
                resolver,
            )

        self.resolvers[resolver_name] = resolver

    def get_resolver(self, resolver_name: str) -> LoRAResolver:
        """Look up the resolver registered as *resolver_name*.

        Args:
            resolver_name: Name of the resolver to get.

        Returns:
            The resolver instance.

        Raises:
            KeyError: If the resolver is not found in the registry.
        """
        if resolver_name in self.resolvers:
            return self.resolvers[resolver_name]

        available = list(self.resolvers.keys())
        raise KeyError(f"LoRA resolver '{resolver_name}' not found. "
                       f"Available resolvers: {available}")


# Module-level singleton through which resolvers are registered and fetched.
LoRAResolverRegistry = _LoRAResolverRegistry()
Loading