From d3c002eadcbaabe1cc2e5fe94321cdc6383cd4e3 Mon Sep 17 00:00:00 2001 From: Brian Li Date: Thu, 22 Aug 2024 01:33:35 +0800 Subject: [PATCH 01/16] [Bugfix] chat method add_generation_prompt param (#7734) --- vllm/entrypoints/llm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 372e96e3716aa..31175724c6c79 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -353,7 +353,7 @@ def chat( use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, chat_template: Optional[str] = None, - add_generation_template: bool = True, + add_generation_prompt: bool = True, ) -> List[RequestOutput]: """ Generates responses for chat messages. @@ -374,7 +374,7 @@ def chat( lora_request: LoRA request to use for generation, if any. chat_template: The template to use for structuring the chat. If not provided, the model's default chat template will be used. - add_generation_template: If True, adds a generation template + add_generation_prompt: If True, adds a generation template to each message. Returns: @@ -392,7 +392,7 @@ def chat( tokenizer, conversations, chat_template=chat_template, - add_generation_template=add_generation_template) + add_generation_prompt=add_generation_prompt) return self.generate( prompts, From f7e3b0c5aa2862d27b8872084c5ca934659ceef8 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Wed, 21 Aug 2024 13:34:14 -0400 Subject: [PATCH 02/16] [Bugfix][Frontend] Fix Issues Under High Load With `zeromq` Frontend (#7394) Co-authored-by: Nick Hill --- .buildkite/test-pipeline.yaml | 1 + tests/entrypoints/openai/test_accuracy.py | 55 +++++ vllm/engine/async_llm_engine.py | 5 + vllm/engine/protocol.py | 4 + vllm/entrypoints/launcher.py | 9 + vllm/entrypoints/openai/api_server.py | 11 +- vllm/entrypoints/openai/rpc/__init__.py | 14 +- vllm/entrypoints/openai/rpc/client.py | 248 ++++++++++++++++------ vllm/entrypoints/openai/rpc/server.py | 116 +++++----- 9 files changed, 322 insertions(+), 141 deletions(-) create mode 100644 tests/entrypoints/openai/test_accuracy.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 59d7241bd452d..aa90145705f9d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -86,6 +86,7 @@ steps: - vllm/ commands: - pip install -e ./plugins/vllm_add_dummy_model + - pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@a4987bba6e9e9b3f22bd3a6c1ecf0abd04fd5622#egg=lm_eval[api] - pytest -v -s entrypoints/llm - pytest -v -s entrypoints/openai diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py new file mode 100644 index 0000000000000..b442a903c33ae --- /dev/null +++ b/tests/entrypoints/openai/test_accuracy.py @@ -0,0 +1,55 @@ +""" +This file test accuracy of the vLLM server via LMEval. +It uses local-completions, which interacts with vLLM +through the OAI API with N concurrent connections. +This simulates real work usage of the API and makes +sure that the zmq frontend mp RPC message passing and +AsyncLLMEngine are working correctly. 
+""" + +import lm_eval +import pytest + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct" +NUM_CONCURRENT = 500 +TASK = "gsm8k" +FILTER = "exact_match,strict-match" +RTOL = 0.03 +EXPECTED_VALUE = 0.58 + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--max-model-len", "4096", "--enable-chunked-prefill", + "--disable-log-requests", "--enforce-eager" + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.fixture(scope="module") +def server_data(server): + return { + "url": f"{server.url_for('v1')}/completions", + } + + +def test_lm_eval_accuracy(server_data): + model_args = (f"model={MODEL_NAME}," + f"base_url={server_data['url']}," + f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") + + results = lm_eval.simple_evaluate( + model="local-completions", + model_args=model_args, + tasks=TASK, + ) + + measured_value = results["results"][TASK][FILTER] + assert (measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index ceda0b83a2397..9911cc9bdd84f 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -766,6 +766,11 @@ def is_stopped(self) -> bool: def errored(self) -> bool: return self._errored_with is not None + @property + def limit_concurrency(self) -> Optional[int]: + """Maximum number of concurrently running requests.""" + return None + def set_errored(self, exc: Exception) -> None: self._errored_with = exc diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index cb16775a1cd59..6c7fd96a7f8e5 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -29,6 +29,10 @@ def is_stopped(self) -> bool: def errored(self) -> bool: ... + @property + def limit_concurrency(self) -> Optional[int]: + """Maximum number of concurrently running requests.""" + def generate( self, inputs: PromptInputs, diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index f4a9c61a431c1..3598872b65bb0 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -27,6 +27,15 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient, logger.info("Route: %s, Methods: %s", path, ', '.join(methods)) + # Set concurrency limits in uvicorn if running in multiprocessing mode + # since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536). + if engine.limit_concurrency is not None: + logger.info( + "Launching Uvicorn with --limit_concurrency %s. To avoid this " + "limit at the expense of performance run with " + "--disable-frontend-multiprocessing", engine.limit_concurrency) + uvicorn_kwargs["limit_concurrency"] = engine.limit_concurrency + config = uvicorn.Config(app, **uvicorn_kwargs) server = uvicorn.Server(config) _add_shutdown_handlers(app, server, engine) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index f37c7f4d91d57..266bf79dcdd65 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -135,6 +135,12 @@ async def build_async_engine_client( logger.info("Multiprocessing frontend to use %s for RPC Path.", rpc_path) + # Build RPCClient, which conforms to AsyncEngineClient Protocol. + # NOTE: Actually, this is not true yet. 
We still need to support + # embedding models via RPC (see TODO above) + rpc_client = AsyncEngineRPCClient(rpc_path) + async_engine_client = rpc_client # type: ignore + # Start RPCServer in separate process (holds the AsyncLLMEngine). context = multiprocessing.get_context("spawn") # the current process might have CUDA context, @@ -145,11 +151,6 @@ async def build_async_engine_client( rpc_server_process.start() logger.info("Started engine process with PID %d", rpc_server_process.pid) - # Build RPCClient, which conforms to AsyncEngineClient Protocol. - # NOTE: Actually, this is not true yet. We still need to support - # embedding models via RPC (see TODO above) - rpc_client = AsyncEngineRPCClient(rpc_path) - async_engine_client = rpc_client # type: ignore try: while True: diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py index 8a7b12201cab7..981dfbfc6670e 100644 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ b/vllm/entrypoints/openai/rpc/__init__.py @@ -7,8 +7,18 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams +# Success string used for RPC instructions. VLLM_RPC_SUCCESS_STR = "SUCCESS" -VLLM_RPC_HEALTHY_STR = "HEALTHY" + +# Timeouts. +VLLM_RPC_SERVER_START_TIMEOUT_MS = 1000 +VLLM_RPC_HEALTH_TIMEOUT_MS = 10000 + +# Minimum value of ZMQ.SOCKET_LIMIT to run mp. +VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000 + +# HWM is set to Infinity. +VLLM_RPC_ZMQ_HWM = 0 @dataclass @@ -34,7 +44,7 @@ class RPCUtilityRequest(Enum): GET_SCHEDULER_CONFIG = 5 GET_LORA_CONFIG = 6 DO_LOG_STATS = 7 - CHECK_HEALTH = 8 + IS_SERVER_HEALTHY = 8 IS_TRACING_ENABLED = 9 diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index 64a20b33d8f3e..7e360d1defb10 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -1,5 +1,7 @@ +import asyncio from contextlib import contextmanager from typing import Any, AsyncGenerator, Mapping, Optional +from uuid import uuid4 import cloudpickle import zmq @@ -7,32 +9,140 @@ from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) +# yapf: disable from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, - VLLM_RPC_HEALTHY_STR, - VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + VLLM_RPC_HEALTH_TIMEOUT_MS, + VLLM_RPC_SERVER_START_TIMEOUT_MS, + VLLM_RPC_SOCKET_LIMIT_CUTOFF, + VLLM_RPC_SUCCESS_STR, + VLLM_RPC_ZMQ_HWM, RPCAbortRequest, RPCGenerateRequest, RPCUtilityRequest) +# yapf: enable from vllm.inputs import PromptInputs +from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs -# Time to wait before checking it the server process is alive. -SERVER_START_TIMEOUT_MS = 1000 +logger = init_logger(__name__) + +# Path used for inprocess proxy. +INPROC_PROXY_PATH = f"inproc://{uuid4()}" class AsyncEngineRPCClient: + """ + RPCClient that connects to the RPCServer wrapping AsyncLLMEngine. 
+ + The overall design mirrors the Asynchronous Client Server Pattern + https://zguide.zeromq.org/docs/chapter3/#The-Asynchronous-Client-Server-Pattern + + On startup, the RPCClient: + - makes DEALER socket (to_rpc_server) that connects to the RPCServer + via ipc, which uses unix sockets under the hood + (https://libzmq.readthedocs.io/en/zeromq4-1/zmq_ipc.html) + - makes ROUTER socket (from_api_server) that binds to a random + inproc address, which uses memory under the hood + (https://libzmq.readthedocs.io/en/zeromq3-x/zmq_inproc.html) + - runs a proxy in a background asyncio task between + from_api_server (ROUTER, inproc) and to_rpc_server (DEALER ipc, ) + + Each request handled by the asyncio api_server calls generate(): + - make a DEALER socket that connects to from_api_server via inproc + - send a RCPGenerateRequest to the inproc socket + - background proxy forwards the request from inproc -> ipc + - RPCServer responds to the request one token at a time over ipc + - background proxy forwards the response from ipc -> inproc + + The connection looks like this: + DEALER <- inproc -> [ ROUTER | DEALER ] <- ipc -> DEALER + + Message routing is performed via identities that are managed by the + ROUTER socket. ROUTER sockets track every connection it has and + tells the caller about these. The way it tells the caller is to stick + the connection identity in front of each message received. When we + send the message via a ROUTER, we first send an identity frame. + See https://zguide.zeromq.org/docs/chapter3/#The-Extended-Reply-Envelope + for more details on connection identities. + + This proxy design enables us to use a single unix socket, which + improves performance by avoiding syscalls (~5%) and avoids resource limits + such as ulimit, which defaults to 1024 on ubuntu. + + Note: we run set_hwm(0) on each socket, which sets the HWM to inf, + which is required to avoid dropping messages under high load. + This is generally not advisable. However, since we are in control + of both sides of the connection + failure on either side is + catastrophic to the overall system health and memory profiling + suggests limited memory overhead relative to asyncio, we will + proceed for now. + + See https://zguide.zeromq.org/docs/chapter2/#High-Water-Marks + for more details on high water marks. + """ def __init__(self, rpc_path: str): self.context = zmq.asyncio.Context() - self.rpc_path = rpc_path + + # Maximum number of sockets that can be opened (typically 65536). + # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get) + socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT) + if socket_limit < VLLM_RPC_SOCKET_LIMIT_CUTOFF: + raise ValueError( + f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps " + "the number of concurrent requests vLLM can process. Launch " + "vLLM with --disable-frontend-multiprocessing and open a " + "GitHub issue so we can investigate.") + + # We only have 1 ipc connection that uses unix sockets, so + # safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. will + # not run into ulimit issues) + self.context.set(zmq.constants.MAX_SOCKETS, socket_limit) + + # IPC connection to RPC Server (uses unix sockets). + self.to_rpc_server = self.context.socket(zmq.constants.DEALER) + self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM) + self.to_rpc_server.bind(rpc_path) + + # In process proxy to RPC Server (uses memory-based messaging). 
+ self.from_api_server = self.context.socket(zmq.constants.ROUTER) + self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM) + self.from_api_server.bind(INPROC_PROXY_PATH) + + # Asyncio background task for the proxy. + self.proxy_task = asyncio.create_task( + self.run_proxy(self.from_api_server, self.to_rpc_server)) + + # Since we open 1 inproc socket per request, we have a hard cap on + # the number of requests that can run in vLLM w. frontend + # mulitprocessing. This value is used uvicorn to launch + # with --limit-concurrency to return 503 when server is overloaded. + # We need 2 sockets per request - 2: + # 1 for generate(), 1 for abort(), do_log_stats(), check_health() + self.limit_concurrency = socket_limit // 2 - 2 + + async def run_proxy(self, socket_from, socket_to): + """Background task that runs a proxy""" + poller = zmq.asyncio.Poller() + poller.register(socket_from, zmq.constants.POLLIN) + poller.register(socket_to, zmq.constants.POLLIN) + while True: + events = await poller.poll() + events = dict(events) + if socket_from in events: + identity, msg = await socket_from.recv_multipart() + await socket_to.send_multipart([identity, msg]) + if socket_to in events: + identity, msg = await socket_to.recv_multipart() + await socket_from.send_multipart([identity, msg]) async def setup(self): """Setup the client before it starts sending server requests.""" # Wait until server is ready. - await self.wait_for_server() + await self._wait_for_server_rpc() self._errored = False # Get the configs. @@ -51,29 +161,23 @@ async def setup(self): def close(self): """Destroy the ZeroMQ Context.""" + # Close all sockets associated with this context and + # then terminate the context. + self.from_api_server.close() + self.to_rpc_server.close() self.context.destroy() @contextmanager - def socket(self): - # Ensure client sockets are always closed after use - - # Connect to RPC socket for Request-Reply pattern, + def to_proxy_socket(self): + # Connect to the RPCServer via the proxy. # Note that we use DEALER to enable asynchronous communication # to enable streaming. socket = self.context.socket(zmq.constants.DEALER) + socket.set_hwm(VLLM_RPC_ZMQ_HWM) try: - socket.connect(self.rpc_path) + socket.connect(INPROC_PROXY_PATH) yield socket finally: - # linger == 0 means discard unsent messages - # when the socket is closed. This is necessary - # because otherwise self.context.destroy() will - # wait for 30 seconds until unsent messages are - # received, which is impossible if the server - # crashed. In the absence of a server crash we - # always expect a response before closing the - # socket anyway. - # Reference: http://api.zeromq.org/4-2:zmq-setsockopt#toc24 socket.close(linger=0) async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, @@ -81,10 +185,9 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, error_message: str) -> Any: """Send an RPC request that is expecting data back.""" - with self.socket() as socket: - + with self.to_proxy_socket() as socket: # Ping RPCServer with a request. - await socket.send(cloudpickle.dumps(request)) + await socket.send_multipart([cloudpickle.dumps(request)]) # Await the data from the Server. data = cloudpickle.loads(await socket.recv()) @@ -93,31 +196,48 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, # LoRAConfig can be None. 
if expected_type == LoRAConfig and data is None: pass + elif isinstance(data, Exception): + logger.error(error_message) + raise data else: raise ValueError(error_message) return data - async def _send_one_way_rpc_request(self, - request: RPC_REQUEST_TYPE, - error_message: str, - timeout: Optional[int] = None): + async def _send_one_way_rpc_request( + self, + request: RPC_REQUEST_TYPE, + error_message: str, + timeout: Optional[int] = None, + socket: Optional[zmq.asyncio.Socket] = None): """Send one-way RPC request to trigger an action.""" - with self.socket() as socket: - # Ping RPC Server with request. - await socket.send(cloudpickle.dumps(request)) - # Await acknowledgement from RPCServer. + async def do_rpc_call(socket: zmq.asyncio.Socket, + request: RPC_REQUEST_TYPE, + timeout=None): + + await socket.send_multipart([cloudpickle.dumps(request)]) + if timeout is not None and await socket.poll(timeout=timeout) == 0: - raise TimeoutError(f"server didn't reply within {timeout} ms") + raise TimeoutError(f"Server didn't reply within {timeout} ms") + + return cloudpickle.loads(await socket.recv()) - response = cloudpickle.loads(await socket.recv()) + # Make a new socket connection. + if socket is None: + with self.to_proxy_socket() as socket: + response = await do_rpc_call(socket, request, timeout) + + # Use existing socket connection. + else: + response = await do_rpc_call(socket, request, timeout) if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: + if isinstance(response, Exception): + logger.error(error_message) + raise response raise ValueError(error_message) - return response - async def get_tokenizer(self, lora_request: LoRARequest): return await self.tokenizer.get_lora_tokenizer_async(lora_request) @@ -130,13 +250,13 @@ async def get_model_config(self) -> ModelConfig: async def is_tracing_enabled(self) -> bool: return self.tracing_flag - async def wait_for_server(self): + async def _wait_for_server_rpc(self): """Wait for the RPCServer to start up.""" await self._send_one_way_rpc_request( request=RPCUtilityRequest.IS_SERVER_READY, - error_message="Unable to start RPC Server.", - timeout=SERVER_START_TIMEOUT_MS) + error_message="Unable to start RPC Server", + timeout=VLLM_RPC_SERVER_START_TIMEOUT_MS) async def _get_model_config_rpc(self) -> ModelConfig: """Get the ModelConfig object from the RPC Server""" @@ -184,8 +304,7 @@ async def _is_tracing_enabled_rpc(self) -> bool: return await self._send_get_data_rpc_request( RPCUtilityRequest.IS_TRACING_ENABLED, expected_type=bool, - error_message="Could not get is_tracing_enabled flag from RPC " - "Server") + error_message="Could not get is_tracing_enabled from RPC Server") async def abort(self, request_id: str): """Send an ABORT_REQUEST signal to the RPC Server""" @@ -226,8 +345,7 @@ async def generate( finished = False try: - with self.socket() as socket: - + with self.to_proxy_socket() as socket: # Send RPCGenerateRequest to the RPCServer. await socket.send_multipart([ cloudpickle.dumps( @@ -246,43 +364,37 @@ async def generate( request_output = cloudpickle.loads(message) if isinstance(request_output, Exception): - # On exception, check if the server is still healthy. - # Use this to set the sync `is_running` and `errored` - # properties. - try: - await self.check_health() - except Exception: - self._errored = True + # On exception, check if the server is still healthy + # possibly setting the `errored` property. 
+ if not self._errored: + try: + await self.check_health(socket=socket) + except Exception as e: + self._errored = True + logger.exception(repr(e)) + # NB: do before raising here so that the flag is set # by the time the caller receives this exception raise request_output finished = request_output.finished yield request_output + finally: - if not finished: + # Request was canceled by the client. + if not finished and not self._errored: await self.abort(request_id) - async def check_health(self) -> None: + async def check_health(self, + socket: Optional[zmq.asyncio.Socket] = None + ) -> None: """Raise if unhealthy""" - with self.socket() as socket: - - # Ping RPCServer with CHECK_HEALTH request. - await socket.send(cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH) - ) - - # Await the reply from the server. - # TODO: do we need an internal timeout here? - # Or do we expect the external probe to timeout and let this chill? - health_message = cloudpickle.loads(await socket.recv()) - - if isinstance(health_message, Exception): - raise health_message - - if health_message != VLLM_RPC_HEALTHY_STR: - raise ValueError("Expected healthy response from backend but got " - "f{health_message}") + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.IS_SERVER_HEALTHY, + error_message="Got Unhealthy response from RPC Server", + timeout=VLLM_RPC_HEALTH_TIMEOUT_MS, + socket=socket) async def encode(self, *args, **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 770ee77926df9..580b83277cfbe 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -1,6 +1,6 @@ import asyncio import signal -from typing import Any, Coroutine +from typing import Any, Coroutine, Union import cloudpickle import uvloop @@ -9,14 +9,19 @@ from typing_extensions import Never from vllm import AsyncEngineArgs, AsyncLLMEngine -from vllm.entrypoints.openai.rpc import (VLLM_RPC_HEALTHY_STR, - VLLM_RPC_SUCCESS_STR, RPCAbortRequest, +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR, + VLLM_RPC_ZMQ_HWM, RPCAbortRequest, RPCGenerateRequest, RPCUtilityRequest) from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext logger = init_logger(__name__) +CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, + SchedulerConfig, LoRAConfig] + class AsyncEngineRPCServer: @@ -29,9 +34,10 @@ def __init__(self, async_engine_args: AsyncEngineArgs, # Initialize context. self.context = zmq.asyncio.Context() - # Init socket for readiness state. - self.socket = self.context.socket(zmq.constants.ROUTER) - self.socket.bind(rpc_path) + # Init socket. + self.socket = self.context.socket(zmq.constants.DEALER) + self.socket.set_hwm(VLLM_RPC_ZMQ_HWM) + self.socket.connect(rpc_path) def cleanup(self): """Cleanup all resources.""" @@ -41,39 +47,27 @@ def cleanup(self): # Clear the engine reference so that it can be GC'ed. 
del self.engine - async def get_model_config(self, identity): - """Send the ModelConfig""" - model_config = await self.engine.get_model_config() - - await self.socket.send_multipart( - [identity, cloudpickle.dumps(model_config)]) - - async def get_decoding_config(self, identity): - """Send the DecodingConfig""" - decoding_config = await self.engine.get_decoding_config() - - await self.socket.send_multipart( - [identity, cloudpickle.dumps(decoding_config)]) - - async def get_lora_config(self, identity): - lora_config = await self.engine.get_lora_config() - - await self.socket.send_multipart( - [identity, cloudpickle.dumps(lora_config)]) - - async def get_scheduler_config(self, identity): - """Send the SchedulerConfig""" - parallel_config = await self.engine.get_scheduler_config() - - await self.socket.send_multipart( - [identity, cloudpickle.dumps(parallel_config)]) + async def get_config(self, identity, request): + try: + config: CONFIG_TYPE + if request == RPCUtilityRequest.GET_MODEL_CONFIG: + config = await self.engine.get_model_config() + elif request == RPCUtilityRequest.GET_DECODING_CONFIG: + config = await self.engine.get_decoding_config() + elif request == RPCUtilityRequest.GET_LORA_CONFIG: + config = await self.engine.get_lora_config() + elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: + config = await self.engine.get_scheduler_config() + elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: + config = await self.engine.get_parallel_config() + else: + raise ValueError("Unknown Config Request: %s", request) - async def get_parallel_config(self, identity): - """Send the ParallelConfig""" - parallel_config = await self.engine.get_parallel_config() + await self.socket.send_multipart( + [identity, cloudpickle.dumps(config)]) - await self.socket.send_multipart( - [identity, cloudpickle.dumps(parallel_config)]) + except Exception as e: + await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) async def is_tracing_enabled(self, identity): """Send the is_tracing_enabled flag""" @@ -86,31 +80,23 @@ async def do_log_stats(self, identity): """Log stats and confirm success.""" await self.engine.do_log_stats() - await self.socket.send_multipart([ - identity, - cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), - ]) + await self.socket.send_multipart( + [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)]) async def is_server_ready(self, identity): """Notify the client that we are ready.""" - await self.socket.send_multipart([ - identity, - cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), - ]) + await self.socket.send_multipart( + [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)]) async def abort(self, identity, request: RPCAbortRequest): """Abort request and notify the client of success.""" try: # Abort the request in the llm engine. await self.engine.abort(request.request_id) - except Exception: - logger.warning("Failed to abort request %s", request.request_id) - finally: - # Send confirmation to the client. 
- await self.socket.send_multipart([ - identity, - cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), - ]) + result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR + except Exception as e: + result = e + await self.socket.send_multipart([identity, cloudpickle.dumps(result)]) async def generate(self, identity, generate_request: RPCGenerateRequest): try: @@ -127,14 +113,14 @@ async def generate(self, identity, generate_request: RPCGenerateRequest): [identity, cloudpickle.dumps(request_output)]) except Exception as e: - ### Notify client of all failures await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) async def check_health(self, identity): try: await self.engine.check_health() await self.socket.send_multipart( - [identity, cloudpickle.dumps(VLLM_RPC_HEALTHY_STR)]) + [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)]) + except Exception as e: await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) @@ -151,21 +137,19 @@ def _make_handler_coro(self, identity, return self.abort(identity, request) elif isinstance(request, RPCUtilityRequest): - if request == RPCUtilityRequest.GET_MODEL_CONFIG: - return self.get_model_config(identity) - elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: - return self.get_parallel_config(identity) - elif request == RPCUtilityRequest.GET_DECODING_CONFIG: - return self.get_decoding_config(identity) - elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: - return self.get_scheduler_config(identity) - elif request == RPCUtilityRequest.GET_LORA_CONFIG: - return self.get_lora_config(identity) + if request in [ + RPCUtilityRequest.GET_MODEL_CONFIG, + RPCUtilityRequest.GET_PARALLEL_CONFIG, + RPCUtilityRequest.GET_DECODING_CONFIG, + RPCUtilityRequest.GET_SCHEDULER_CONFIG, + RPCUtilityRequest.GET_LORA_CONFIG + ]: + return self.get_config(identity, request) elif request == RPCUtilityRequest.DO_LOG_STATS: return self.do_log_stats(identity) elif request == RPCUtilityRequest.IS_SERVER_READY: return self.is_server_ready(identity) - elif request == RPCUtilityRequest.CHECK_HEALTH: + elif request == RPCUtilityRequest.IS_SERVER_HEALTHY: return self.check_health(identity) elif request == RPCUtilityRequest.IS_TRACING_ENABLED: return self.is_tracing_enabled(identity) From 1b32e0264888200a0e6187496a816ef597a7f320 Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Wed, 21 Aug 2024 18:17:48 +0000 Subject: [PATCH 03/16] [Bugfix] Pass PYTHONPATH from setup.py to CMake (#7730) --- CMakeLists.txt | 2 +- setup.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c8d4aaeda9091..217dc70c4b24e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -233,7 +233,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Generate sources: execute_process( COMMAND ${CMAKE_COMMAND} -E env - PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:$PYTHONPATH + PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py RESULT_VARIABLE machete_generation_result OUTPUT_VARIABLE machete_generation_output diff --git a/setup.py b/setup.py index ef599b613667b..21b0422c0f0bd 100644 --- a/setup.py +++ b/setup.py @@ -184,6 +184,10 @@ def configure(self, ext: CMakeExtension) -> None: # match. cmake_args += ['-DVLLM_PYTHON_EXECUTABLE={}'.format(sys.executable)] + # Pass the python path to cmake so it can reuse the build dependencies + # on subsequent calls to python. 
+ cmake_args += ['-DVLLM_PYTHON_PATH={}'.format(":".join(sys.path))] + # # Setup parallelism and build tool # From 91f4522cbf85df0a65f619b25f6751edf2d5f0d6 Mon Sep 17 00:00:00 2001 From: William Lin Date: Wed, 21 Aug 2024 11:49:19 -0700 Subject: [PATCH 04/16] [multi-step] Raise error if not using async engine (#7703) --- vllm/engine/llm_engine.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 94aed6b8c50c7..f72902c372181 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1302,6 +1302,11 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: raise NotImplementedError( "Pipeline parallelism is only supported through AsyncLLMEngine " "as performance will be severely degraded otherwise.") + + if self.scheduler_config.num_scheduler_steps > 1: + raise NotImplementedError( + "Multiple scheduler steps (multi-step) are only supported " + "through AsyncLLMEngine. ") seq_group_metadata_list, scheduler_outputs = self.scheduler[ 0].schedule() From 970dfdc01d3453c83066e6156278d70bade0350c Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Wed, 21 Aug 2024 15:53:01 -0400 Subject: [PATCH 05/16] [Frontend] Improve Startup Failure UX (#7716) --- .../entrypoints/openai/test_mp_api_server.py | 29 ++++++++++--------- vllm/entrypoints/openai/api_server.py | 27 +++++++++++++---- 2 files changed, 37 insertions(+), 19 deletions(-) diff --git a/tests/entrypoints/openai/test_mp_api_server.py b/tests/entrypoints/openai/test_mp_api_server.py index b9fc0c1422b74..fbfe0db19dd03 100644 --- a/tests/entrypoints/openai/test_mp_api_server.py +++ b/tests/entrypoints/openai/test_mp_api_server.py @@ -1,3 +1,5 @@ +import time + import pytest from vllm.entrypoints.openai.api_server import build_async_engine_client @@ -8,19 +10,20 @@ @pytest.mark.asyncio async def test_mp_crash_detection(): - with pytest.raises(RuntimeError) as excinfo: - parser = FlexibleArgumentParser( - description="vLLM's remote OpenAI server.") - parser = make_arg_parser(parser) - args = parser.parse_args([]) - # use an invalid tensor_parallel_size to trigger the - # error in the server - args.tensor_parallel_size = 65536 - - async with build_async_engine_client(args): - pass - assert "The server process died before responding to the readiness probe"\ - in str(excinfo.value) + parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") + parser = make_arg_parser(parser) + args = parser.parse_args([]) + # use an invalid tensor_parallel_size to trigger the + # error in the server + args.tensor_parallel_size = 65536 + + start = time.perf_counter() + async with build_async_engine_client(args): + pass + end = time.perf_counter() + + assert end - start < 60, ("Expected vLLM to gracefully shutdown in <60s " + "if there is an error in the startup.") @pytest.mark.asyncio diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 266bf79dcdd65..94d8525e429ca 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -8,7 +8,7 @@ from argparse import Namespace from contextlib import asynccontextmanager from http import HTTPStatus -from typing import AsyncIterator, Set +from typing import AsyncIterator, Optional, Set from fastapi import APIRouter, FastAPI, Request from fastapi.exceptions import RequestValidationError @@ -60,6 +60,7 @@ openai_serving_tokenization: OpenAIServingTokenization 
prometheus_multiproc_dir: tempfile.TemporaryDirectory +# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) logger = init_logger('vllm.entrypoints.openai.api_server') _running_tasks: Set[asyncio.Task] = set() @@ -94,7 +95,15 @@ async def _force_log(): @asynccontextmanager async def build_async_engine_client( - args: Namespace) -> AsyncIterator[AsyncEngineClient]: + args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]: + """ + Create AsyncEngineClient, either: + - in-process using the AsyncLLMEngine Directly + - multiprocess using AsyncLLMEngine RPC + + Returns the Client or None if the creation failed. + """ + # Context manager to handle async_engine_client lifecycle # Ensures everything is shutdown and cleaned up on error/exit global engine_args @@ -157,11 +166,13 @@ async def build_async_engine_client( try: await rpc_client.setup() break - except TimeoutError as e: + except TimeoutError: if not rpc_server_process.is_alive(): - raise RuntimeError( - "The server process died before " - "responding to the readiness probe") from e + logger.error( + "RPCServer process died before responding " + "to readiness probe") + yield None + return yield async_engine_client finally: @@ -410,6 +421,10 @@ async def run_server(args, **uvicorn_kwargs) -> None: logger.info("args: %s", args) async with build_async_engine_client(args) as async_engine_client: + # If None, creation of the client failed and we exit. + if async_engine_client is None: + return + app = await init_app(async_engine_client, args) shutdown_task = await serve_http( From dd53c4b023056cda6174cc32dc3d31bc01e8646a Mon Sep 17 00:00:00 2001 From: William Lin Date: Wed, 21 Aug 2024 15:39:26 -0700 Subject: [PATCH 06/16] [misc] Add Torch profiler support (#7451) Co-authored-by: Cody Yu --- benchmarks/backend_request_func.py | 4 +- benchmarks/benchmark_serving.py | 43 +++++++++++++++++++ docs/source/dev/profiling/profiling_index.rst | 33 ++++++++++++++ docs/source/index.rst | 1 + vllm/engine/async_llm_engine.py | 6 +++ vllm/engine/protocol.py | 8 ++++ vllm/entrypoints/openai/api_server.py | 20 +++++++++ vllm/entrypoints/openai/rpc/__init__.py | 2 + vllm/entrypoints/openai/rpc/client.py | 14 ++++++ vllm/entrypoints/openai/rpc/server.py | 24 +++++++++++ vllm/envs.py | 7 +++ vllm/worker/worker.py | 31 +++++++++++++ 12 files changed, 191 insertions(+), 2 deletions(-) create mode 100644 docs/source/dev/profiling/profiling_index.rst diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 3b4e31eaa712e..f7d67692f697b 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -225,8 +225,8 @@ async def async_request_openai_completions( ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith( - "completions" - ), "OpenAI Completions API URL must end with 'completions'." + ("completions", "profile") + ), "OpenAI Completions API URL must end with 'completions' or 'profile'." 
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: assert not request_func_input.use_beam_search diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index fc0dbf77f16b9..fe687da492901 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -295,6 +295,7 @@ def calculate_metrics( async def benchmark( backend: str, api_url: str, + base_url: str, model_id: str, tokenizer: PreTrainedTokenizerBase, input_requests: List[Tuple[str, int, int]], @@ -302,6 +303,7 @@ async def benchmark( use_beam_search: bool, request_rate: float, disable_tqdm: bool, + profile: bool, ): if backend in ASYNC_REQUEST_FUNCS: request_func = ASYNC_REQUEST_FUNCS[backend] @@ -326,6 +328,22 @@ async def benchmark( f"are correctly specified. Error: {test_output.error}") else: print("Initial test run completed. Starting main benchmark run...") + + if profile: + print("Starting profiler...") + profile_input = RequestFuncInput( + model=model_id, + prompt=test_prompt, + api_url=base_url + "/start_profile", + prompt_len=test_prompt_len, + output_len=test_output_len, + best_of=best_of, + use_beam_search=use_beam_search, + ) + profile_output = await request_func(request_func_input=profile_input) + if profile_output.success: + print("Profiler started") + print(f"Traffic request rate: {request_rate}") pbar = None if disable_tqdm else tqdm(total=len(input_requests)) @@ -349,6 +367,21 @@ async def benchmark( pbar=pbar))) outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) + if profile: + print("Stopping profiler...") + profile_input = RequestFuncInput( + model=model_id, + prompt=test_prompt, + api_url=base_url + "/stop_profile", + prompt_len=test_prompt_len, + output_len=test_output_len, + best_of=best_of, + use_beam_search=use_beam_search, + ) + profile_output = await request_func(request_func_input=profile_input) + if profile_output.success: + print("Profiler stopped") + if pbar is not None: pbar.close() @@ -433,8 +466,10 @@ def main(args: argparse.Namespace): if args.base_url is not None: api_url = f"{args.base_url}{args.endpoint}" + base_url = f"{args.base_url}" else: api_url = f"http://{args.host}:{args.port}{args.endpoint}" + base_url = f"http://{args.host}:{args.port}" tokenizer = get_tokenizer(tokenizer_id, trust_remote_code=args.trust_remote_code) @@ -506,6 +541,7 @@ def main(args: argparse.Namespace): benchmark( backend=backend, api_url=api_url, + base_url=base_url, model_id=model_id, tokenizer=tokenizer, input_requests=input_requests, @@ -513,6 +549,7 @@ def main(args: argparse.Namespace): use_beam_search=args.use_beam_search, request_rate=args.request_rate, disable_tqdm=args.disable_tqdm, + profile=args.profile, )) # Save config and results to json @@ -693,6 +730,12 @@ def main(args: argparse.Namespace): action="store_true", help="Specify to disable tqdm progress bar.", ) + parser.add_argument( + "--profile", + action="store_true", + help="Use Torch Profiler. The endpoint must be launched with " + "VLLM_TORCH_PROFILER_DIR to enable profiler.", + ) parser.add_argument( "--save-result", action="store_true", diff --git a/docs/source/dev/profiling/profiling_index.rst b/docs/source/dev/profiling/profiling_index.rst new file mode 100644 index 0000000000000..af3c78c3b5a55 --- /dev/null +++ b/docs/source/dev/profiling/profiling_index.rst @@ -0,0 +1,33 @@ +Profiling vLLM +================================= + +We support tracing vLLM workers using the ``torch.profiler`` module. 
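A simplified sketch of what the worker sets up when the profiler is enabled, mirroring the setup added to ``vllm/worker/worker.py`` in this series (the trace directory below is only an example path):

.. code-block:: python

    import torch

    profiler = torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        with_stack=True,
        on_trace_ready=torch.profiler.tensorboard_trace_handler(
            "/mnt/traces/", use_gzip=True),
    )

    profiler.start()
    # ... run a few requests through the engine ...
    profiler.stop()

The API server exposes matching ``POST /start_profile`` and ``POST /stop_profile`` routes, which ``benchmarks/benchmark_serving.py --profile`` calls before and after the benchmark run.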
You can enable tracing by setting the ``VLLM_TORCH_PROFILER_DIR`` environment variable to the directory where you want to save the traces: ``VLLM_TORCH_PROFILER_DIR=/mnt/traces/`` + +The OpenAI server also needs to be started with the ``VLLM_TORCH_PROFILER_DIR`` environment variable set. + +When using ``benchmarks/benchmark_serving.py``, you can enable profiling by passing the ``--profile`` flag. + +.. warning:: + + Only enable profiling in a development environment. + + +Traces can be visualized using https://ui.perfetto.dev/. + +.. tip:: + + Only send a few requests through vLLM when profiling, as the traces can get quite large. Also, no need to untar the traces, they can be viewed directly. + +Example commands: + +OpenAI Server: + +.. code-block:: bash + + VLLM_TORCH_PROFILER_DIR=/mnt/traces/ python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B + +benchmark_serving.py: + +.. code-block:: bash + + python benchmarks/benchmark_serving.py --backend vllm --model meta-llama/Meta-Llama-3-70B --dataset-name sharegpt --dataset-path sharegpt.json --profile --num-prompts 2 \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index 4e79871e6e78f..4b817c4ba9498 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -136,6 +136,7 @@ Documentation dev/input_processing/model_inputs_index dev/multimodal/multimodal_index dev/dockerfile/dockerfile + dev/profiling/profiling_index .. toctree:: :maxdepth: 1 diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 9911cc9bdd84f..8812b853c0665 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1266,3 +1266,9 @@ def remove_logger(self, logger_name: str) -> None: logger_name=logger_name)) else: self.engine.remove_logger(logger_name=logger_name) + + async def start_profile(self) -> None: + self.engine.model_executor._run_workers("start_profile") + + async def stop_profile(self) -> None: + self.engine.model_executor._run_workers("stop_profile") diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 6c7fd96a7f8e5..1deb75167bc72 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -91,3 +91,11 @@ async def do_log_stats( async def check_health(self) -> None: """Raise if unhealthy""" ... + + async def start_profile(self) -> None: + """Start profiling the engine""" + ... + + async def stop_profile(self) -> None: + """Start profiling the engine""" + ... diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 94d8525e429ca..8e8371ef1559a 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -305,6 +305,26 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): assert_never(generator) +if envs.VLLM_TORCH_PROFILER_DIR: + logger.warning( + "Torch Profiler is enabled in the API server. 
This should ONLY be " + "used for local development!") + + @router.post("/start_profile") + async def start_profile(): + logger.info("Starting profiler...") + await async_engine_client.start_profile() + logger.info("Profiler started.") + return Response(status_code=200) + + @router.post("/stop_profile") + async def stop_profile(): + logger.info("Stopping profiler...") + await async_engine_client.stop_profile() + logger.info("Profiler stopped.") + return Response(status_code=200) + + def build_app(args: Namespace) -> FastAPI: app = FastAPI(lifespan=lifespan) app.include_router(router) diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py index 981dfbfc6670e..571dca5f61fa4 100644 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ b/vllm/entrypoints/openai/rpc/__init__.py @@ -46,6 +46,8 @@ class RPCUtilityRequest(Enum): DO_LOG_STATS = 7 IS_SERVER_HEALTHY = 8 IS_TRACING_ENABLED = 9 + START_PROFILE = 10 + STOP_PROFILE = 11 RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index 7e360d1defb10..1f26348c74d6d 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -400,3 +400,17 @@ async def encode(self, *args, **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: raise NotImplementedError( "Embeddings not supported with multiprocessing backend") + + async def start_profile(self) -> None: + """Start profiling the engine""" + + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.START_PROFILE, + error_message="RPCRequest START_PROFILE failed.") + + async def stop_profile(self) -> None: + """Stop profiling the engine""" + + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.STOP_PROFILE, + error_message="RPCRequest STOP_PROFILE failed.") \ No newline at end of file diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 580b83277cfbe..738d12bbef051 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -124,6 +124,26 @@ async def check_health(self, identity): except Exception as e: await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) + async def start_profile(self, identity): + logger.info("Starting profiler...") + await self.engine.start_profile() + logger.info("Profiler started.") + + await self.socket.send_multipart([ + identity, + cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), + ]) + + async def stop_profile(self, identity): + logger.info("Stopping profiler...") + await self.engine.stop_profile() + logger.info("Profiler stopped.") + + await self.socket.send_multipart([ + identity, + cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), + ]) + def _make_handler_coro(self, identity, message) -> Coroutine[Any, Any, Never]: """Route the zmq message to the handler coroutine.""" @@ -153,6 +173,10 @@ def _make_handler_coro(self, identity, return self.check_health(identity) elif request == RPCUtilityRequest.IS_TRACING_ENABLED: return self.is_tracing_enabled(identity) + elif request == RPCUtilityRequest.START_PROFILE: + return self.start_profile(identity) + elif request == RPCUtilityRequest.STOP_PROFILE: + return self.stop_profile(identity) else: raise ValueError(f"Unknown RPCUtilityRequest type: {request}") diff --git a/vllm/envs.py b/vllm/envs.py index 115ead01f537d..e4cf6a028ac18 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -58,6 +58,7 @@ VLLM_TEST_FORCE_FP8_MARLIN: bool = False VLLM_ALLOW_ENGINE_USE_RAY: 
bool = False VLLM_PLUGINS: Optional[List[str]] = None + VLLM_TORCH_PROFILER_DIR: Optional[str] = None def get_default_cache_root(): @@ -384,6 +385,12 @@ def get_default_config_root(): "VLLM_PLUGINS": lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[ "VLLM_PLUGINS"].split(","), + + # Enables torch profiler if set. Path to the directory where torch profiler + # traces are saved. Note that it must be an absolute path. + "VLLM_TORCH_PROFILER_DIR": + lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os + .path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))), } # end-env-vars-definition diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 97be68934be46..331a805caba9a 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -6,6 +6,7 @@ import torch import torch.distributed +import vllm.envs as envs from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, @@ -13,6 +14,7 @@ from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) +from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.model_executor.model_loader.tensorizer import TensorizerConfig @@ -27,6 +29,8 @@ from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput +logger = init_logger(__name__) + class Worker(LocalOrDistributedWorkerBase): """A worker class that executes (a partition of) the model on a GPU. @@ -113,6 +117,33 @@ def __init__( self.gpu_cache: Optional[List[List[torch.Tensor]]] = None self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {} + # Torch profiler. Enabled and configured through env vars: + # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + if envs.VLLM_TORCH_PROFILER_DIR: + torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + logger.info("Profiling enabled. 
Traces will be saved to: %s", + torch_profiler_trace_dir) + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=True, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + torch_profiler_trace_dir, use_gzip=True)) + else: + self.profiler = None + + def start_profile(self): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + self.profiler.start() + + def stop_profile(self): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + self.profiler.stop() + def _is_encoder_decoder_model(self): return self.model_config.is_encoder_decoder_model From 1ca0d4f86bd6db76bd601df16647fed53e495a0a Mon Sep 17 00:00:00 2001 From: Peter Salas Date: Wed, 21 Aug 2024 15:49:39 -0700 Subject: [PATCH 07/16] [Model] Add UltravoxModel and UltravoxConfig (#7615) --- docs/source/models/supported_models.rst | 7 +- examples/offline_inference_audio_language.py | 97 ++++ examples/openai_audio_api_client.py | 90 ++++ tests/conftest.py | 31 +- ...t_basic_distributed_correctness_enc_dec.py | 3 +- tests/entrypoints/openai/test_audio.py | 148 +----- tests/models/test_bart.py | 3 +- tests/models/test_blip2.py | 5 +- tests/models/test_chameleon.py | 4 +- tests/models/test_llava.py | 5 +- tests/models/test_llava_image_embeds.py | 5 +- tests/models/test_llava_next.py | 5 +- tests/models/test_paligemma.py | 5 +- tests/models/test_qwen.py | 2 +- tests/models/test_ultravox.py | 151 ++++++ vllm/assets/audio.py | 26 ++ vllm/entrypoints/chat_utils.py | 6 +- vllm/model_executor/models/__init__.py | 3 +- vllm/model_executor/models/blip.py | 8 +- vllm/model_executor/models/chameleon.py | 8 +- vllm/model_executor/models/clip.py | 8 +- vllm/model_executor/models/fuyu.py | 4 +- vllm/model_executor/models/internvl.py | 2 +- vllm/model_executor/models/minicpmv.py | 4 +- vllm/model_executor/models/paligemma.py | 2 +- vllm/model_executor/models/phi3v.py | 2 +- vllm/model_executor/models/siglip.py | 8 +- vllm/model_executor/models/ultravox.py | 435 ++++++++++++++++++ vllm/multimodal/image.py | 83 ---- vllm/multimodal/utils.py | 90 +++- vllm/transformers_utils/config.py | 3 +- vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/ultravox.py | 99 ++++ 33 files changed, 1090 insertions(+), 264 deletions(-) create mode 100644 examples/offline_inference_audio_language.py create mode 100644 examples/openai_audio_api_client.py create mode 100644 tests/models/test_ultravox.py create mode 100644 vllm/assets/audio.py create mode 100644 vllm/model_executor/models/ultravox.py create mode 100644 vllm/transformers_utils/configs/ultravox.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index c761d1b32cd91..1692e13c4ec06 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -186,7 +186,7 @@ Multimodal Language Models * - Architecture - Models - - Supported Modality(ies) + - Supported Modalities - Example HuggingFace Models - :ref:`LoRA ` * - :code:`Blip2ForConditionalGeneration` @@ -234,6 +234,11 @@ Multimodal Language Models - Image - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc. - + * - :code: `UltravoxModel` + - Ultravox + - Audio + - :code: `fixie-ai/ultravox-v0_3` + - .. note:: For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. 
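For the new audio modality (e.g. Ultravox), inputs are passed to vLLM as an ``(audio, sampling_rate)`` tuple under the ``"audio"`` key of ``multi_modal_data``. A rough offline sketch using a local file is shown below; the file name and raw prompt are placeholders, ``librosa`` is assumed to be installed, and the full example script added next builds the prompt with the model's chat template:

.. code-block:: python

    import librosa

    from vllm import LLM, SamplingParams

    # Placeholder file name; any format librosa can read should work.
    audio, sr = librosa.load("question.ogg", sr=None)

    llm = LLM(model="fixie-ai/ultravox-v0_3")
    outputs = llm.generate(
        {
            # In practice, build the prompt with tokenizer.apply_chat_template()
            # so the <|reserved_special_token_0|> audio placeholder is positioned
            # by the model's own chat template.
            "prompt": "<|reserved_special_token_0|>\nWhat is recited in the audio?",
            "multi_modal_data": {"audio": (audio, sr)},
        },
        SamplingParams(temperature=0.2, max_tokens=64),
    )
    print(outputs[0].outputs[0].text)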
diff --git a/examples/offline_inference_audio_language.py b/examples/offline_inference_audio_language.py new file mode 100644 index 0000000000000..7b886f8e2001a --- /dev/null +++ b/examples/offline_inference_audio_language.py @@ -0,0 +1,97 @@ +""" +This example shows how to use vLLM for running offline inference +with the correct prompt format on vision language models. + +For most models, the prompt format should follow corresponding examples +on HuggingFace model repository. +""" +from transformers import AutoTokenizer + +from vllm import LLM, SamplingParams +from vllm.assets.audio import AudioAsset +from vllm.utils import FlexibleArgumentParser + +# Input audio and question +audio_and_sample_rate = AudioAsset("mary_had_lamb").audio_and_sample_rate +question = "What is recited in the audio?" + + +# Ultravox 0.3 +def run_ultravox(question): + model_name = "fixie-ai/ultravox-v0_3" + + tokenizer = AutoTokenizer.from_pretrained(model_name) + messages = [{ + 'role': 'user', + 'content': f"<|reserved_special_token_0|>\n{question}" + }] + prompt = tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + + llm = LLM(model=model_name) + stop_token_ids = None + return llm, prompt, stop_token_ids + + +model_example_map = { + "ultravox": run_ultravox, +} + + +def main(args): + model = args.model_type + if model not in model_example_map: + raise ValueError(f"Model type {model} is not supported.") + + llm, prompt, stop_token_ids = model_example_map[model](question) + + # We set temperature to 0.2 so that outputs can be different + # even when all prompts are identical when running batch inference. + sampling_params = SamplingParams(temperature=0.2, + max_tokens=64, + stop_token_ids=stop_token_ids) + + assert args.num_prompts > 0 + if args.num_prompts == 1: + # Single inference + inputs = { + "prompt": prompt, + "multi_modal_data": { + "audio": audio_and_sample_rate + }, + } + + else: + # Batch inference + inputs = [{ + "prompt": prompt, + "multi_modal_data": { + "audio": audio_and_sample_rate + }, + } for _ in range(args.num_prompts)] + + outputs = llm.generate(inputs, sampling_params=sampling_params) + + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description='Demo on using vLLM for offline inference with ' + 'audio language models') + parser.add_argument('--model-type', + '-m', + type=str, + default="ultravox", + choices=model_example_map.keys(), + help='Huggingface "model_type".') + parser.add_argument('--num-prompts', + type=int, + default=1, + help='Number of prompts to run.') + + args = parser.parse_args() + main(args) diff --git a/examples/openai_audio_api_client.py b/examples/openai_audio_api_client.py new file mode 100644 index 0000000000000..80a972683871f --- /dev/null +++ b/examples/openai_audio_api_client.py @@ -0,0 +1,90 @@ +"""An example showing how to use vLLM to serve VLMs. + +Launch the vLLM server with the following command: +vllm serve fixie-ai/ultravox-v0_3 +""" +import base64 + +import requests +from openai import OpenAI + +from vllm.assets.audio import AudioAsset + +# Modify OpenAI's API key and API base to use vLLM's API server. 
+openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + +client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, +) + +models = client.models.list() +model = models.data[0].id + +# Any format supported by librosa is supported +audio_url = AudioAsset("winning_call").url + +# Use audio url in the payload +chat_completion_from_url = client.chat.completions.create( + messages=[{ + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's in this audio?" + }, + { + "type": "audio_url", + "audio_url": { + "url": audio_url + }, + }, + ], + }], + model=model, + max_tokens=64, +) + +result = chat_completion_from_url.choices[0].message.content +print(f"Chat completion output:{result}") + + +# Use base64 encoded audio in the payload +def encode_audio_base64_from_url(audio_url: str) -> str: + """Encode an audio retrieved from a remote url to base64 format.""" + + with requests.get(audio_url) as response: + response.raise_for_status() + result = base64.b64encode(response.content).decode('utf-8') + + return result + + +audio_base64 = encode_audio_base64_from_url(audio_url=audio_url) +chat_completion_from_base64 = client.chat.completions.create( + messages=[{ + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's in this audio?" + }, + { + "type": "audio_url", + "audio_url": { + # Any format supported by librosa is supported + "url": f"data:audio/ogg;base64,{audio_base64}" + }, + }, + ], + }], + model=model, + max_tokens=64, +) + +result = chat_completion_from_base64.choices[0].message.content +print(f"Chat completion output:{result}") diff --git a/tests/conftest.py b/tests/conftest.py index 08a2c8fcda021..ae362b228d9d8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,14 +9,14 @@ from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union) +import numpy as np import pytest import torch import torch.nn as nn import torch.nn.functional as F from huggingface_hub import snapshot_download from PIL import Image -from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM, - AutoModelForVision2Seq, AutoTokenizer, BatchEncoding, +from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding, BatchFeature) from vllm import LLM, SamplingParams @@ -216,8 +216,7 @@ def __init__( *, model_kwargs: Optional[Dict[str, Any]] = None, is_embedding_model: bool = False, - is_vision_model: bool = False, - is_encoder_decoder_model: bool = False, + auto_cls=AutoModelForCausalLM, postprocess_inputs: Callable[[BatchEncoding], BatchEncoding] = identity, ) -> None: @@ -234,13 +233,6 @@ def __init__( device="cpu", ).to(dtype=torch_dtype)) else: - if is_vision_model: - auto_cls = AutoModelForVision2Seq - elif is_encoder_decoder_model: - auto_cls = AutoModelForSeq2SeqLM - else: - auto_cls = AutoModelForCausalLM - model_kwargs = model_kwargs if model_kwargs is not None else {} self.model = self.wrap_device( auto_cls.from_pretrained( @@ -432,6 +424,7 @@ def generate_greedy_logprobs_limit( max_tokens: int, num_logprobs: int, images: Optional[List[Image.Image]] = None, + audios: Optional[List[Tuple[np.ndarray, int]]] = None, **kwargs: Any, ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]: all_logprobs: List[List[Dict[int, float]]] = [] @@ -446,6 +439,11 @@ def generate_greedy_logprobs_limit( if images is not None and images[i] is not None: processor_kwargs["images"] = images[i] + if audios is not None: + audio, sr = audios[i] + 
processor_kwargs["audio"] = audio + processor_kwargs["sampling_rate"] = sr + inputs = self.processor(**processor_kwargs) inputs = self.postprocess_inputs(inputs) @@ -627,6 +625,8 @@ def generate_w_logprobs( sampling_params: SamplingParams, images: Optional[Union[List[Image.Image], List[List[Image.Image]]]] = None, + audios: Optional[Union[List[Tuple[np.ndarray, int]], + List[List[Tuple[np.ndarray, int]]]]] = None ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: assert sampling_params.logprobs is not None @@ -638,6 +638,10 @@ def generate_w_logprobs( for i, image in enumerate(images): inputs[i]["multi_modal_data"] = {"image": image} + if audios is not None: + for i, audio in enumerate(audios): + inputs[i]["multi_modal_data"] = {"audio": audio} + req_outputs = self.model.generate(inputs, sampling_params=sampling_params) return self._final_steps_generate_w_logprobs(req_outputs) @@ -674,6 +678,8 @@ def generate_greedy_logprobs( num_logprobs: int, images: Optional[Union[List[Image.Image], List[List[Image.Image]]]] = None, + audios: Optional[Union[List[Tuple[np.ndarray, int]], + List[List[Tuple[np.ndarray, int]]]]] = None, stop_token_ids: Optional[List[int]] = None, ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: greedy_logprobs_params = SamplingParams(temperature=0.0, @@ -682,7 +688,8 @@ def generate_greedy_logprobs( stop_token_ids=stop_token_ids) outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params, - images=images) + images=images, + audios=audios) return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] diff --git a/tests/distributed/test_basic_distributed_correctness_enc_dec.py b/tests/distributed/test_basic_distributed_correctness_enc_dec.py index 9850c823ff5da..f00d5ef584a2a 100644 --- a/tests/distributed/test_basic_distributed_correctness_enc_dec.py +++ b/tests/distributed/test_basic_distributed_correctness_enc_dec.py @@ -10,6 +10,7 @@ """ import pytest +from transformers import AutoModelForSeq2SeqLM from vllm.utils import cuda_device_count_stateless @@ -85,7 +86,7 @@ def test_models( } with hf_runner(model, dtype=dtype, - is_encoder_decoder_model=True) as hf_model: + auto_cls=AutoModelForSeq2SeqLM) as hf_model: hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( test_prompts, max_tokens, diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index 39b47f3033715..6dc8dde667389 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -1,138 +1,36 @@ -import math -import sys -import time -from typing import Dict, List, Optional, Tuple, Union, cast -from unittest.mock import patch - -import librosa -import numpy as np +from typing import Dict, List + import openai import pytest -import requests -import torch - -from vllm import ModelRegistry -from vllm.config import MultiModalConfig -from vllm.inputs import INPUT_REGISTRY -from vllm.inputs.data import LLMInputs -from vllm.inputs.registry import InputContext -from vllm.model_executor.models.interfaces import SupportsMultiModal -from vllm.model_executor.models.opt import OPTForCausalLM -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalInputs -from vllm.multimodal.image import (cached_get_tokenizer, - repeat_and_pad_image_tokens) -from vllm.multimodal.utils import encode_audio_base64, fetch_audio -from vllm.utils import get_open_port -from ...utils import VLLM_PATH +from vllm.assets.audio import AudioAsset +from 
vllm.multimodal.utils import encode_audio_base64, fetch_audio -chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja" -assert chatml_jinja_path.exists() +from ...utils import RemoteOpenAIServer -MODEL_NAME = "facebook/opt-125m" +MODEL_NAME = "fixie-ai/ultravox-v0_3" TEST_AUDIO_URLS = [ - "https://upload.wikimedia.org/wikipedia/en/b/bf/Dave_Niehaus_Winning_Call_1995_AL_Division_Series.ogg", + AudioAsset("winning_call").url, ] -def server_function(port): - - def fake_input_mapper(ctx: InputContext, data: object): - assert isinstance(data, tuple) - (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data) - - # Resample it to 1 sample per second - audio = librosa.resample(audio, orig_sr=sr, target_sr=1) - return MultiModalInputs({"processed_audio": torch.from_numpy(audio)}) - - def fake_input_processor(ctx: InputContext, llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") - if multi_modal_data is None or "audio" not in multi_modal_data: - return llm_inputs - - audio, sr = multi_modal_data.get("audio") - audio_duration = math.ceil(len(audio) / sr) - - new_prompt, new_token_ids = repeat_and_pad_image_tokens( - cached_get_tokenizer(ctx.model_config.tokenizer), - llm_inputs.get("prompt"), - llm_inputs["prompt_token_ids"], - image_token_id=62, # "_" - repeat_count=audio_duration) - - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) - - @MULTIMODAL_REGISTRY.register_input_mapper("audio", fake_input_mapper) - @MULTIMODAL_REGISTRY.register_max_multimodal_tokens( - "audio", lambda *_, **__: 100) - @INPUT_REGISTRY.register_input_processor(fake_input_processor) - class FakeAudioModel(OPTForCausalLM, SupportsMultiModal): - - def __init__(self, *args, multimodal_config: MultiModalConfig, - **kwargs): - assert multimodal_config is not None - super().__init__(*args, **kwargs) - - def forward( - self, - *args, - processed_audio: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - return super().forward(*args, **kwargs) - - ModelRegistry.register_model("OPTForCausalLM", FakeAudioModel) - - with patch( - "vllm.entrypoints.chat_utils._mm_token_str", - lambda *_, **__: "_"), patch( - "vllm.model_executor.models.ModelRegistry.is_multimodal_model" - ) as mock: - mock.return_value = True - sys.argv = ["placeholder.py"] + \ - (f"--model {MODEL_NAME} --gpu-memory-utilization 0.10 " - "--dtype bfloat16 --enforce-eager --api-key token-abc123 " - f"--port {port} --chat-template {chatml_jinja_path} " - "--disable-frontend-multiprocessing").split() - import runpy - runpy.run_module('vllm.entrypoints.openai.api_server', - run_name='__main__') +@pytest.fixture(scope="module") +def server(): + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "4096", + "--enforce-eager", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server @pytest.fixture(scope="module") -def client(): - port = get_open_port() - ctx = torch.multiprocessing.get_context("spawn") - server = ctx.Process(target=server_function, args=(port, )) - server.start() - MAX_SERVER_START_WAIT_S = 60 - client = openai.AsyncOpenAI( - base_url=f"http://localhost:{port}/v1", - api_key="token-abc123", - ) - # run health check - health_url = f"http://localhost:{port}/health" - start = time.time() - while True: - try: - if requests.get(health_url).status_code == 200: - break - except Exception as err: - result = server.exitcode - if result is not None: - raise RuntimeError("Server exited unexpectedly.") from err - - 
time.sleep(0.5) - if time.time() - start > MAX_SERVER_START_WAIT_S: - raise RuntimeError("Server failed to start in time.") from err - - try: - yield client - finally: - server.kill() +def client(server): + return server.get_async_client() @pytest.fixture(scope="session") @@ -176,7 +74,7 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI, choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=36, total_tokens=46) + completion_tokens=10, prompt_tokens=202, total_tokens=212) message = choice.message message = chat_completion.choices[0].message @@ -231,7 +129,7 @@ async def test_single_chat_session_audio_base64encoded( choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=36, total_tokens=46) + completion_tokens=10, prompt_tokens=202, total_tokens=212) message = choice.message message = chat_completion.choices[0].message diff --git a/tests/models/test_bart.py b/tests/models/test_bart.py index 9bca5a86f1241..660b61d1a7ade 100644 --- a/tests/models/test_bart.py +++ b/tests/models/test_bart.py @@ -12,6 +12,7 @@ # (xFormers, etc.) import pytest + from transformers import AutoModelForSeq2SeqLM from vllm.sequence import SampleLogprobs @@ -131,7 +132,7 @@ def test_models( } with hf_runner(model, dtype=dtype, - is_encoder_decoder_model=True) as hf_model: + auto_cls=AutoModelForSeq2SeqLM) as hf_model: hf_outputs = ( hf_model.generate_encoder_decoder_greedy_logprobs_limit( test_case_prompts, diff --git a/tests/models/test_blip2.py b/tests/models/test_blip2.py index 64b7a77404b98..5d48bad0d7b35 100644 --- a/tests/models/test_blip2.py +++ b/tests/models/test_blip2.py @@ -1,7 +1,7 @@ from typing import List, Optional, Tuple import pytest -from transformers import AutoTokenizer +from transformers import AutoModelForVision2Seq, AutoTokenizer from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs @@ -80,7 +80,8 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, for prompts, images in inputs_per_image ] - with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model: + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForVision2Seq) as hf_model: hf_outputs_per_image = [ hf_model.generate_greedy_logprobs_limit(prompts, max_tokens, diff --git a/tests/models/test_chameleon.py b/tests/models/test_chameleon.py index 5e7e0e6258f8a..e02b4b1ed72bd 100644 --- a/tests/models/test_chameleon.py +++ b/tests/models/test_chameleon.py @@ -1,7 +1,7 @@ from typing import List, Optional, Type import pytest -from transformers import BatchEncoding +from transformers import AutoModelForVision2Seq, BatchEncoding from vllm.multimodal.utils import rescale_image_size from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE @@ -74,7 +74,7 @@ def process(hf_inputs: BatchEncoding): with hf_runner(model, dtype=dtype, postprocess_inputs=process, - is_vision_model=True) as hf_model: + auto_cls=AutoModelForVision2Seq) as hf_model: hf_outputs_per_image = [ hf_model.generate_greedy_logprobs_limit(prompts, max_tokens, diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py index edaf7d400eb53..93634f245cee7 100644 --- a/tests/models/test_llava.py +++ b/tests/models/test_llava.py @@ -1,7 +1,8 @@ from typing import List, Optional, Tuple, Type import pytest -from transformers import AutoConfig, AutoTokenizer, 
BatchEncoding +from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer, + BatchEncoding) from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs @@ -124,7 +125,7 @@ def process(hf_inputs: BatchEncoding): with hf_runner(model, dtype=dtype, postprocess_inputs=process, - is_vision_model=True) as hf_model: + auto_cls=AutoModelForVision2Seq) as hf_model: hf_outputs_per_image = [ hf_model.generate_greedy_logprobs_limit(prompts, max_tokens, diff --git a/tests/models/test_llava_image_embeds.py b/tests/models/test_llava_image_embeds.py index 63ccd1f6625c8..cc444fe32e79b 100644 --- a/tests/models/test_llava_image_embeds.py +++ b/tests/models/test_llava_image_embeds.py @@ -1,7 +1,7 @@ from typing import List, Optional, Tuple, Type import pytest -from transformers import AutoConfig, AutoTokenizer +from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer from vllm.sequence import SampleLogprobs @@ -105,7 +105,8 @@ def run_test( for prompts, images in vllm_inputs_per_image ] - with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model: + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForVision2Seq) as hf_model: hf_outputs_per_image = [ hf_model.generate_greedy_logprobs_limit(prompts, max_tokens, diff --git a/tests/models/test_llava_next.py b/tests/models/test_llava_next.py index 2bd27f888680d..9cf55c0858df0 100644 --- a/tests/models/test_llava_next.py +++ b/tests/models/test_llava_next.py @@ -1,7 +1,7 @@ from typing import List, Optional, Tuple, Type, overload import pytest -from transformers import AutoConfig, AutoTokenizer +from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs @@ -129,7 +129,8 @@ def run_test( for prompts, images in inputs_per_image ] - with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model: + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForVision2Seq) as hf_model: hf_outputs_per_image = [ hf_model.generate_greedy_logprobs_limit(prompts, max_tokens, diff --git a/tests/models/test_paligemma.py b/tests/models/test_paligemma.py index 038a22f71acad..beddaaf608a18 100644 --- a/tests/models/test_paligemma.py +++ b/tests/models/test_paligemma.py @@ -2,7 +2,7 @@ from typing import List, Optional, Tuple, Type import pytest -from transformers import AutoConfig, AutoTokenizer +from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs @@ -102,7 +102,8 @@ def run_test( for prompts, images in inputs_per_image ] - with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model: + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForVision2Seq) as hf_model: hf_outputs_per_image = [ hf_model.generate_greedy_logprobs_limit(prompts, max_tokens, diff --git a/tests/models/test_qwen.py b/tests/models/test_qwen.py index 03605e3b34810..0f974fcc1885c 100644 --- a/tests/models/test_qwen.py +++ b/tests/models/test_qwen.py @@ -26,7 +26,7 @@ def test_text_only_qwen_model( # for qwen-vl is still unsupported in VLLM. In the near-future, the # implementation and this test will be extended to consider # visual inputs as well. 
- with hf_runner(model, dtype=dtype, is_vision_model=False) as hf_model: + with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, diff --git a/tests/models/test_ultravox.py b/tests/models/test_ultravox.py new file mode 100644 index 0000000000000..98de10aa08408 --- /dev/null +++ b/tests/models/test_ultravox.py @@ -0,0 +1,151 @@ +from typing import List, Optional, Tuple, Type + +import librosa +import numpy as np +import pytest +from transformers import AutoModel, AutoTokenizer, BatchEncoding + +from vllm.assets.audio import AudioAsset +from vllm.sequence import SampleLogprobs +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE + +from ..conftest import HfRunner, VllmRunner +from .utils import check_logprobs_close + +pytestmark = pytest.mark.vlm + +MODEL_NAME = "fixie-ai/ultravox-v0_3" + +AudioTuple = Tuple[np.ndarray, int] + + +@pytest.fixture(scope="session") +def audio_and_sample_rate(): + return AudioAsset("mary_had_lamb").audio_and_sample_rate + + +@pytest.fixture +def prompts_and_audios(audio_and_sample_rate): + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + + vllm_placeholder = "<|reserved_special_token_0|>" + hf_placeholder = "<|audio|>" + + question = "What's in the audio?" + vllm_prompt = tokenizer.apply_chat_template( + [{ + 'role': 'user', + 'content': f"{vllm_placeholder}\n{question}" + }], + tokenize=False, + add_generation_prompt=True) + hf_prompt = tokenizer.apply_chat_template( + [{ + 'role': 'user', + 'content': f"{hf_placeholder}\n{question}" + }], + tokenize=False, + add_generation_prompt=True) + + return [(vllm_prompt, hf_prompt, audio_and_sample_rate)] + + +def vllm_to_hf_output(vllm_output: Tuple[List[int], str, + Optional[SampleLogprobs]], + model: str): + """Sanitize vllm output to be comparable with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + + tokenizer = AutoTokenizer.from_pretrained(model) + eos_token_id = tokenizer.eos_token_id + + hf_output_ids = output_ids[:] + hf_output_str = output_str + if hf_output_ids[-1] == eos_token_id: + hf_output_str = hf_output_str + tokenizer.decode(eos_token_id) + + return hf_output_ids, hf_output_str, out_logprobs + + +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + prompts_and_audios: List[Tuple[str, str, AudioTuple]], + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + """Inference result should be the same between hf and vllm.""" + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] + + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). 
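+    # (If the order ever has to change, forcing the spawn start method, e.g.
+    # via VLLM_WORKER_MULTIPROC_METHOD=spawn, is a possible workaround, but
+    # running vLLM first keeps the test setup simplest.)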
+ + with vllm_runner(model, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True) as vllm_model: + vllm_outputs_per_audio = [ + vllm_model.generate_greedy_logprobs([vllm_prompt], + max_tokens, + num_logprobs=num_logprobs, + audios=[audio]) + for vllm_prompt, _, audio in prompts_and_audios + ] + + def process(hf_inputs: BatchEncoding): + hf_inputs["audio_values"] = hf_inputs["audio_values"] \ + .to(torch_dtype) # type: ignore + return hf_inputs + + with hf_runner(model, + dtype=dtype, + postprocess_inputs=process, + auto_cls=AutoModel) as hf_model: + + hf_outputs_per_audio = [ + hf_model.generate_greedy_logprobs_limit( + [hf_prompt], + max_tokens, + num_logprobs=num_logprobs, + audios=[(librosa.resample(audio[0], + orig_sr=audio[1], + target_sr=16000), 16000)]) + for _, hf_prompt, audio in prompts_and_audios + ] + + for hf_outputs, vllm_outputs in zip(hf_outputs_per_audio, + vllm_outputs_per_audio): + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, model) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models(hf_runner, vllm_runner, prompts_and_audios, dtype: str, + max_tokens: int, num_logprobs: int) -> None: + run_test( + hf_runner, + vllm_runner, + prompts_and_audios, + MODEL_NAME, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) diff --git a/vllm/assets/audio.py b/vllm/assets/audio.py new file mode 100644 index 0000000000000..b00a61ebfec65 --- /dev/null +++ b/vllm/assets/audio.py @@ -0,0 +1,26 @@ +from dataclasses import dataclass +from typing import Literal, Tuple +from urllib.parse import urljoin + +import librosa +import numpy as np + +from vllm.assets.base import get_vllm_public_assets, vLLM_S3_BUCKET_URL + +ASSET_DIR = "multimodal_asset" + + +@dataclass(frozen=True) +class AudioAsset: + name: Literal["winning_call", "mary_had_lamb"] + + @property + def audio_and_sample_rate(self) -> Tuple[np.ndarray, int]: + + audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg", + s3_prefix=ASSET_DIR) + return librosa.load(audio_path, sr=None) + + @property + def url(self) -> str: + return urljoin(vLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg") diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 48fd1333d8f40..19d1095084293 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -117,8 +117,8 @@ def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer, modality: Literal["image", "audio"]) -> Optional[str]: # TODO: Let user specify how to insert image tokens into prompt # (similar to chat template) + model_type = model_config.hf_config.model_type if modality == "image": - model_type = model_config.hf_config.model_type if model_type == "phi3_v": # Workaround since this token is not defined in the tokenizer return "<|image_1|>" @@ -134,7 +134,9 @@ def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer, raise TypeError(f"Unknown model type: {model_type}") elif modality == "audio": - raise TypeError("No audio models are supported yet.") + if model_type == "ultravox": + return "<|reserved_special_token_0|>" + raise TypeError(f"Unknown model type: {model_type}") else: raise TypeError(f"Unknown modality: {modality}") diff --git 
a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 32cafa845a6e3..bdf6e502ea112 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -61,7 +61,7 @@ "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "MedusaModel": ("medusa", "Medusa"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), - "JambaForCausalLM": ("jamba", "JambaForCausalLM") + "JambaForCausalLM": ("jamba", "JambaForCausalLM"), } _EMBEDDING_MODELS = { @@ -83,6 +83,7 @@ "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), + "UltravoxModel": ("ultravox", "UltravoxModel"), } _CONDITIONAL_GENERATION_MODELS = { "BartModel": ("bart", "BartForConditionalGeneration"), diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index 69e777152e3d4..830680fd990bf 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -15,8 +15,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.multimodal.image import (cached_get_tokenizer, - repeat_and_pad_image_tokens) +from vllm.multimodal.utils import (cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData @@ -97,11 +97,11 @@ def input_processor_for_blip( else: image_feature_size = image_feature_size_override - new_prompt, new_token_ids = repeat_and_pad_image_tokens( + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, llm_inputs.get("prompt"), llm_inputs["prompt_token_ids"], - image_token_id=image_token_id, + placeholder_token_id=image_token_id, repeat_count=image_feature_size, ) diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 788d22db9d5a8..a335e1766b2a9 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -30,8 +30,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.image import (cached_get_tokenizer, - repeat_and_pad_image_tokens) +from vllm.multimodal.utils import (cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, SamplerOutput, SequenceData) from vllm.utils import print_warning_once @@ -124,11 +124,11 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs): model_config = ctx.model_config tokenizer = cached_get_tokenizer(model_config.tokenizer) - new_prompt, new_token_ids = repeat_and_pad_image_tokens( + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, llm_inputs.get("prompt"), llm_inputs["prompt_token_ids"], - image_token_id=CHAMELEON_IMAGE_TOKEN_ID, + placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID, repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH, pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID, pad_token_right=CHAMELEON_IMAGE_END_TOKEN_ID, diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 24eeefdfccf00..0933966055330 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -16,8 +16,8 @@ RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from 
vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.multimodal.image import (cached_get_tokenizer, - repeat_and_pad_image_tokens) +from vllm.multimodal.utils import (cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData @@ -103,11 +103,11 @@ def input_processor_for_clip( else: image_feature_size = image_feature_size_override - new_prompt, new_token_ids = repeat_and_pad_image_tokens( + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, llm_inputs.get("prompt"), llm_inputs["prompt_token_ids"], - image_token_id=image_token_id, + placeholder_token_id=image_token_id, repeat_count=image_feature_size, ) diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 2ef23819b69a2..cfc2a5288a37b 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -36,8 +36,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs -from vllm.multimodal.image import (cached_get_image_processor, - cached_get_tokenizer) +from vllm.multimodal.image import cached_get_image_processor +from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, SamplerOutput, SequenceData) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index b379c86c1912b..c996f0b73f293 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -23,7 +23,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs -from vllm.multimodal.image import cached_get_tokenizer +from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SamplerOutput from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 99a3c5dab39e4..29f3640e2458b 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -54,8 +54,8 @@ from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.image import (cached_get_image_processor, - cached_get_tokenizer) +from vllm.multimodal.image import cached_get_image_processor +from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, SamplerOutput, SequenceData) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 8beb2778fe37a..8cb5065ed79ec 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -16,7 +16,7 @@ from vllm.model_executor.models.gemma import GemmaModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.image import cached_get_tokenizer +from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsMultiModal diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 
328f4e6fa827c..9ccd6ef6d9ace 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -37,7 +37,7 @@ from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.image import cached_get_tokenizer +from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SamplerOutput from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 426af7fee9544..7f6186fa010a4 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -24,8 +24,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.multimodal.image import (cached_get_tokenizer, - repeat_and_pad_image_tokens) +from vllm.multimodal.utils import (cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData @@ -112,11 +112,11 @@ def input_processor_for_siglip( else: image_feature_size = image_feature_size_override - new_prompt, new_token_ids = repeat_and_pad_image_tokens( + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, llm_inputs.get("prompt"), llm_inputs["prompt_token_ids"], - image_token_id=image_token_id, + placeholder_token_id=image_token_id, repeat_count=image_feature_size, ) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py new file mode 100644 index 0000000000000..842264f765866 --- /dev/null +++ b/vllm/model_executor/models/ultravox.py @@ -0,0 +1,435 @@ +# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py +"""PyTorch Ultravox model.""" + +import itertools +import math +from array import array +from functools import lru_cache +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union, cast) + +import librosa +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import functional as F +from transformers.models.whisper import WhisperFeatureExtractor +from transformers.models.whisper.modeling_whisper import WhisperEncoder + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, MultiModalConfig +from vllm.inputs import INPUT_REGISTRY +from vllm.inputs.data import LLMInputs +from vllm.inputs.registry import InputContext +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import SupportsMultiModal +from vllm.model_executor.models.utils import (filter_weights, + init_vllm_registered_model, + merge_multimodal_embeddings) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.utils import (cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, 
SamplerOutput, SequenceData +from vllm.transformers_utils.configs.ultravox import UltravoxConfig + +_AUDIO_PLACEHOLDER_TOKEN = 128002 +_AUDIO_TOKENS_PER_SECOND = 6.25 + +logger = init_logger(__name__) + + +class UltravoxAudioFeatureInputs(TypedDict): + type: Literal["audio_features"] + data: Union[torch.Tensor, List[torch.Tensor]] + """Shape: `(batch_size, 80, M)""" + + +class UltravoxAudioEmbeddingInputs(TypedDict): + type: Literal["audio_embeds"] + data: torch.Tensor + + +UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs, + UltravoxAudioEmbeddingInputs] + + +@lru_cache +def cached_feature_extractor(model_id: str) -> WhisperFeatureExtractor: + return WhisperFeatureExtractor.from_pretrained(model_id) + + +def whisper_feature_extractor(ctx: InputContext) -> WhisperFeatureExtractor: + return cached_feature_extractor( + ctx.get_hf_config(UltravoxConfig).audio_model_id) + + +def get_ultravox_max_audio_tokens(ctx: InputContext): + feature_extractor = whisper_feature_extractor(ctx) + return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND) + + +def dummy_data_for_ultravox( + ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], +): + feature_extractor = whisper_feature_extractor(ctx) + + audio_count = mm_counts["audio"] + + audio_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [ + _AUDIO_PLACEHOLDER_TOKEN + ]) * get_ultravox_max_audio_tokens(ctx) * audio_count + other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0]) * (seq_len - len(audio_token_ids)) + + audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1) + mm_dict = { + "audio": + audio_and_sr if audio_count == 1 else [audio_and_sr] * audio_count + } + + return (SequenceData(audio_token_ids + other_token_ids), mm_dict) + + +def input_mapper_for_ultravox(ctx: InputContext, data: object): + if isinstance(data, tuple): + (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data) + feature_extractor = whisper_feature_extractor(ctx) + + if sr != feature_extractor.sampling_rate: + audio = librosa.resample(audio, + orig_sr=sr, + target_sr=feature_extractor.sampling_rate) + sr = feature_extractor.sampling_rate + + minimum_audio_length = feature_extractor.n_fft // 2 + 1 + if len(audio) < minimum_audio_length: + # Not enough audio; pad it. + audio = np.pad(audio, (0, minimum_audio_length - len(audio))) + + return MultiModalInputs({ + "audio_features": + feature_extractor(audio, + sampling_rate=sr, + padding="longest", + return_tensors="pt")["input_features"] + }) + + raise NotImplementedError(f"Unsupported data type: {type(data)}") + + +def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "audio" not in multi_modal_data: + return llm_inputs + + feature_extractor = whisper_feature_extractor(ctx) + audio_data, sample_rate = multi_modal_data["audio"] + + audio_length = audio_data.shape[0] + if sample_rate != feature_extractor.sampling_rate: + # Account for resampling. 
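+        # Worked example (assuming the Whisper feature-extractor defaults of
+        # a 16 kHz sampling rate and hop_length=160, plus the default
+        # stack_factor=8): 4 s of 8 kHz audio is 32000 samples, adjustment is
+        # 2.0, so it is treated as 64000 samples, i.e. ~400 mel frames and
+        # 25 audio tokens, matching _AUDIO_TOKENS_PER_SECOND = 6.25.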
+ adjustment = feature_extractor.sampling_rate / sample_rate + audio_length = math.ceil(adjustment * audio_length) + + feature_extractor_output_length = math.ceil( + (audio_length - + (feature_extractor.hop_length - 1)) / feature_extractor.hop_length) + + uv_config = ctx.get_hf_config(UltravoxConfig) + audio_num_tokens = min( + max( + 1, + math.ceil(feature_extractor_output_length / + (uv_config.stack_factor * 2))), + get_ultravox_max_audio_tokens(ctx)) + tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer) + + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + tokenizer, + llm_inputs.get("prompt"), + llm_inputs["prompt_token_ids"], + placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN, + repeat_count=audio_num_tokens, + ) + + # NOTE: Create a defensive copy of the original inputs + return LLMInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) + + +class StackAudioFrames(nn.Module): + """ + Stack the audio embedding frames to reduce the sequence length by a factor + of `stack_factor`. + """ + + def __init__(self, stack_factor: int = 8): + super().__init__() + self.stack_factor = stack_factor + + def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor: + B, T, C = audio_embeds.shape + T_pad = (T + self.stack_factor - + 1) // self.stack_factor * self.stack_factor + audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T)) + B, T, C = audio_embeds.shape + audio_embeds = audio_embeds.view(B, T // self.stack_factor, + C * self.stack_factor) + return audio_embeds + + +class FlippedSiluAndMul(SiluAndMul): + """Ultravox is trained with SwiGLU with flipped halves.""" + + def forward(self, x: torch.Tensor): + a, b = x.chunk(2, dim=-1) + flipped = torch.cat((b, a), dim=-1) + return super().forward(flipped) + + +class UltravoxProjector(nn.Module): + + def __init__(self, config: UltravoxConfig): + super().__init__() + self.hidden_dim = config.hidden_size + self._pad_and_stack = StackAudioFrames(config.stack_factor) + dim = config.audio_config.hidden_size * config.stack_factor + self.ln_pre = RMSNorm(dim) + self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False) + dim = self.hidden_dim + + if config.projector_act == "swiglu": + self.act = FlippedSiluAndMul() + dim = dim // 2 + else: + self.act = get_act_fn(config.projector_act) + + self.linear_2 = nn.Linear(dim, + config.text_config.hidden_size, + bias=False) + self.ln_post = RMSNorm(config.text_config.hidden_size) + + def forward(self, audio_features: torch.Tensor) -> torch.Tensor: + audio_features = self._pad_and_stack(audio_features) + audio_features = self.ln_pre(audio_features) + hidden_states = self.linear_1(audio_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + hidden_states = self.ln_post(hidden_states) + return hidden_states + + +class ModifiedWhisperEncoder(WhisperEncoder): + """ + Encoder portion of OpenAI's Whisper model. + + This implementation is a slightly modified version of HF Transformers' + Whisper Encoder, with only a few fixes: + 1. base_model_prefix updated to allow for doing `.from_pretrained` + directly on the encoder + 2. 
allow less than 30 second of audio padding to be passed in: + - relaxed ValueError check for `input_features` length to be less + than or equal to `expected_seq_length` instead of strictly equal + - embed_pos is now sliced to match the length of `inputs_embeds` + + Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py + See commentary: https://github.com/huggingface/transformers/issues/25744 + """ + + base_model_prefix = "model.encoder" + + def forward( + self, + input_features, + ): + expected_seq_length = (self.config.max_source_positions * + self.conv1.stride[0] * self.conv2.stride[0]) + if input_features.shape[-1] > expected_seq_length: + raise ValueError( + f"Whisper expects the mel input features to be of length " + f"{expected_seq_length} or less, but found " + f"{input_features.shape[-1]}. Make sure to pad the input mel " + f"features to {expected_seq_length}.") + + inputs_embeds = nn.functional.gelu(self.conv1(input_features)) + inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) + + inputs_embeds = inputs_embeds.permute(0, 2, 1) + embed_pos = self.embed_positions.weight[:inputs_embeds.size(-2)] + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) + + for encoder_layer in self.layers: + layer_outputs = encoder_layer( + hidden_states, + None, + layer_head_mask=None, + ) + + hidden_states = layer_outputs[0] + + hidden_states = self.layer_norm(hidden_states) + return hidden_states + + +@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_ultravox) +@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( + "audio", get_ultravox_max_audio_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox) +@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox) +class UltravoxModel(nn.Module, SupportsMultiModal): + + def __init__(self, + config: UltravoxConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional["QuantizationConfig"] = None): + super().__init__() + self.config = config + self.multi_modal_config = multimodal_config + assert self.multi_modal_config + + if config.audio_model_id is not None: + self.audio_tower = ModifiedWhisperEncoder.from_pretrained( + config.audio_model_id) + else: + self.audio_tower = ModifiedWhisperEncoder(config.audio_config) + self.multi_modal_projector = UltravoxProjector(config) + self.language_model = init_vllm_registered_model( + config.text_config, cache_config, quant_config) + + def _audio_features_to_embeddings( + self, input_features: torch.Tensor) -> torch.Tensor: + audio_input = input_features.to(self.audio_tower.dtype) + audio_features = self.audio_tower(audio_input) + audio_features = audio_features.to(self.audio_tower.dtype) + audio_embeddings = self.multi_modal_projector(audio_features) + return audio_embeddings + + def _parse_and_validate_audio_input( + self, **kwargs: object) -> Optional[UltravoxAudioInputs]: + audio_features = kwargs.pop("audio_features", None) + audio_embeds = kwargs.pop("audio_embeds", None) + + if audio_features is None and audio_embeds is None: + return None + + if audio_features is not None: + if not isinstance(audio_features, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio features. 
" + f"Got type: {type(audio_features)}") + + return UltravoxAudioFeatureInputs(type="audio_features", + data=audio_features) + + if audio_embeds is not None: + if not isinstance(audio_embeds, torch.Tensor): + raise ValueError("Incorrect type of audio embeds. " + f"Got type: {type(audio_embeds)}") + + return UltravoxAudioEmbeddingInputs(type="audio_embeds", + data=audio_embeds) + + raise AssertionError("This line should be unreachable.") + + def _process_audio_input( + self, audio_input: UltravoxAudioInputs + ) -> Union[torch.Tensor, List[torch.Tensor]]: + if audio_input["type"] == "audio_embeds": + return audio_input["data"] + + audio_features = audio_input["data"] + if isinstance(audio_features, list): + # TODO: Batch these through the encoder/projector instead of + # serializing them. + return [ + self._audio_features_to_embeddings( + features.unsqueeze(0)).squeeze(0) + for features in audio_features + ] + else: + return self._audio_features_to_embeddings(audio_features) + + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[torch.Tensor], + **kwargs) -> SamplerOutput: + """Run forward pass for Ultravox + + One key thing to understand is the `input_ids` already accounts for the + positions of the to-be-inserted audio embeddings. The to-be-inserted + audio has a size that is essentially 6.25 tokens per second of audio. + + This way, the `positions` and `attn_metadata` are consistent + with the `input_ids`. + + Args: + input_features: A batch of audio inputs, [1, 80, M]. + """ + audio_input = self._parse_and_validate_audio_input(**kwargs) + if audio_input is not None: + audio_embeddings = self._process_audio_input(audio_input) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, audio_embeddings, + _AUDIO_PLACEHOLDER_TOKEN) + input_ids = None + else: + inputs_embeds = None + + hidden_states = self.language_model.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # prepare weight iterators for components + projector_weights, llm_weights = itertools.tee(weights, 2) + + # load projector weights + projector_weights = filter_weights(projector_weights, + "multi_modal_projector") + projector_params_dict = dict( + self.multi_modal_projector.named_parameters()) + for name, loaded_weight in projector_weights: + param = projector_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # load llm backbone + llm_weights = filter_weights(llm_weights, "language_model") + self.language_model.load_weights(llm_weights) diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 916bd5e601bb7..6cdde949bc2b1 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -1,5 +1,4 @@ from functools 
import lru_cache -from typing import List, Optional, Tuple, TypeVar import torch from PIL import Image @@ -8,7 +7,6 @@ from vllm.inputs.registry import InputContext from vllm.logger import init_logger from vllm.transformers_utils.image_processor import get_image_processor -from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer from vllm.utils import is_list_of from .base import MultiModalData, MultiModalInputs, MultiModalPlugin @@ -16,87 +14,6 @@ logger = init_logger(__name__) cached_get_image_processor = lru_cache(get_image_processor) -cached_get_tokenizer = lru_cache(get_tokenizer) - -# Utilities for image input processors -_T = TypeVar("_T", str, int) - - -def repeat_and_pad_token( - token: _T, - *, - repeat_count: int = 1, - pad_token_left: Optional[_T] = None, - pad_token_right: Optional[_T] = None, -) -> List[_T]: - replacement = [token] * repeat_count - if pad_token_left is not None: - replacement = [pad_token_left] + replacement - if pad_token_right is not None: - replacement = replacement + [pad_token_right] - - return replacement - - -def repeat_and_pad_image_tokens( - tokenizer: AnyTokenizer, - prompt: Optional[str], - prompt_token_ids: List[int], - *, - image_token_id: int, - repeat_count: int = 1, - pad_token_left: Optional[int] = None, - pad_token_right: Optional[int] = None, -) -> Tuple[Optional[str], List[int]]: - if prompt is None: - new_prompt = None - else: - image_token_str = tokenizer.decode(image_token_id) - pad_token_str_left = (None if pad_token_left is None else - tokenizer.decode(pad_token_left)) - pad_token_str_right = (None if pad_token_right is None else - tokenizer.decode(pad_token_right)) - replacement_str = "".join( - repeat_and_pad_token( - image_token_str, - repeat_count=repeat_count, - pad_token_left=pad_token_str_left, - pad_token_right=pad_token_str_right, - )) - - image_token_count = prompt.count(image_token_str) - # This is an arbitrary number to distinguish between the two cases - if image_token_count > 16: - logger.warning( - "Please follow the prompt format that is " - "documented on HuggingFace which does not involve " - "repeating %s tokens.", image_token_str) - elif image_token_count > 1: - logger.warning("Multiple image input is not supported yet, " - "so any extra image tokens will be treated " - "as plain text.") - - # The image tokens are removed to be consistent with HuggingFace - new_prompt = prompt.replace(image_token_str, replacement_str, 1) - - new_token_ids: List[int] = [] - for i, token in enumerate(prompt_token_ids): - if token == image_token_id: - replacement_ids = repeat_and_pad_token( - image_token_id, - repeat_count=repeat_count, - pad_token_left=pad_token_left, - pad_token_right=pad_token_right, - ) - new_token_ids.extend(replacement_ids) - - # No need to further scan the list since we only replace once - new_token_ids.extend(prompt_token_ids[i + 1:]) - break - else: - new_token_ids.append(token) - - return new_prompt, new_token_ids class ImagePlugin(MultiModalPlugin): diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index d1e624cdb8ace..3bf430235462b 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -1,6 +1,7 @@ import base64 +from functools import lru_cache from io import BytesIO -from typing import Tuple, Union +from typing import List, Optional, Tuple, TypeVar, Union import librosa import numpy as np @@ -9,7 +10,13 @@ from vllm.connections import global_http_connection from vllm.envs import VLLM_AUDIO_FETCH_TIMEOUT, VLLM_IMAGE_FETCH_TIMEOUT +from vllm.logger import 
init_logger from vllm.multimodal.base import MultiModalDataDict +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer + +logger = init_logger(__name__) + +cached_get_tokenizer = lru_cache(get_tokenizer) def _load_image_from_bytes(b: bytes): @@ -154,3 +161,84 @@ def rescale_image_size(image: Image.Image, if transpose >= 0: image = image.transpose(Image.Transpose(transpose)) return image + + +# Utilities for input processors +_T = TypeVar("_T", str, int) + + +def repeat_and_pad_token( + token: _T, + *, + repeat_count: int = 1, + pad_token_left: Optional[_T] = None, + pad_token_right: Optional[_T] = None, +) -> List[_T]: + replacement = [token] * repeat_count + if pad_token_left is not None: + replacement = [pad_token_left] + replacement + if pad_token_right is not None: + replacement = replacement + [pad_token_right] + + return replacement + + +def repeat_and_pad_placeholder_tokens( + tokenizer: AnyTokenizer, + prompt: Optional[str], + prompt_token_ids: List[int], + *, + placeholder_token_id: int, + repeat_count: int = 1, + pad_token_left: Optional[int] = None, + pad_token_right: Optional[int] = None, +) -> Tuple[Optional[str], List[int]]: + if prompt is None: + new_prompt = None + else: + placeholder_token_str = tokenizer.decode(placeholder_token_id) + pad_token_str_left = (None if pad_token_left is None else + tokenizer.decode(pad_token_left)) + pad_token_str_right = (None if pad_token_right is None else + tokenizer.decode(pad_token_right)) + replacement_str = "".join( + repeat_and_pad_token( + placeholder_token_str, + repeat_count=repeat_count, + pad_token_left=pad_token_str_left, + pad_token_right=pad_token_str_right, + )) + + placeholder_token_count = prompt.count(placeholder_token_str) + # This is an arbitrary number to distinguish between the two cases + if placeholder_token_count > 16: + logger.warning( + "Please follow the prompt format that is " + "documented on HuggingFace which does not involve " + "repeating %s tokens.", placeholder_token_str) + elif placeholder_token_count > 1: + logger.warning("Multiple multi-modal input is not supported yet, " + "so any extra placeholder tokens will be treated " + "as plain text.") + + # The image tokens are removed to be consistent with HuggingFace + new_prompt = prompt.replace(placeholder_token_str, replacement_str, 1) + + new_token_ids: List[int] = [] + for i, token in enumerate(prompt_token_ids): + if token == placeholder_token_id: + replacement_ids = repeat_and_pad_token( + placeholder_token_id, + repeat_count=repeat_count, + pad_token_left=pad_token_left, + pad_token_right=pad_token_right, + ) + new_token_ids.extend(replacement_ids) + + # No need to further scan the list since we only replace once + new_token_ids.extend(prompt_token_ids[i + 1:]) + break + else: + new_token_ids.append(token) + + return new_prompt, new_token_ids diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 5f04b39ef524e..d3024965c0b4c 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -12,7 +12,7 @@ InternVLChatConfig, JAISConfig, MedusaConfig, MLPSpeculatorConfig, MPTConfig, NemotronConfig, - RWConfig) + RWConfig, UltravoxConfig) if VLLM_USE_MODELSCOPE: from modelscope import AutoConfig @@ -32,6 +32,7 @@ "medusa": MedusaConfig, "internvl_chat": InternVLChatConfig, "nemotron": NemotronConfig, + "ultravox": UltravoxConfig, } for name, cls in _CONFIG_REGISTRY.items(): diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py 
index 5ccacd4a4c40a..22b906a3149ec 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -10,6 +10,7 @@ from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig +from vllm.transformers_utils.configs.ultravox import UltravoxConfig __all__ = [ "ChatGLMConfig", @@ -21,4 +22,5 @@ "MedusaConfig", "MLPSpeculatorConfig", "NemotronConfig", + "UltravoxConfig", ] diff --git a/vllm/transformers_utils/configs/ultravox.py b/vllm/transformers_utils/configs/ultravox.py new file mode 100644 index 0000000000000..f724bf7f2f1cd --- /dev/null +++ b/vllm/transformers_utils/configs/ultravox.py @@ -0,0 +1,99 @@ +# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_config.py +from typing import Any, Dict, Optional + +import transformers + + +class UltravoxConfig(transformers.PretrainedConfig): + r""" + This is the configuration class to store the configuration of a + [`UltravoxForConditionalGeneration`]. It is used to instantiate an + Ultravox model according to the specified arguments, defining the model + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to + control the model outputs. Read the documentation from [`PretrainedConfig`] + for more information. + + Args: + audio_config (`Union[AutoConfig, dict]`, *optional*): + Custom audio config or dict + text_config (`Union[AutoConfig, dict]`, *optional*): + The config object of the text backbone. Can be any of `LlamaConfig` + or `MistralConfig`. + ignore_index (`int`, *optional*, defaults to -100): + The ignore index for the loss function. + audio_token_index (`int`, *optional*, defaults to 32000): + The audio token index to encode the audio prompt. + stack_factor (`int`, *optional*, defaults to 8): + Audio downsampling factor for the multimodal projector. + norm_init (`float`, *optional*, defaults to 0.4): + The initialization value for the layer normalization. + projector_act (`str`, *optional*, defaults to `"swiglu"`): + The activation function used by the multimodal projector. + text_model_lora_config (`LoraConfigSimplified`, *optional*): + The LoRA configuration for finetuning the text model. + audio_model_lora_config (`LoraConfigSimplified`, *optional*): + The LoRA configuration for finetuning the audio model. 
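+
+    Example (a minimal, illustrative construction using default sub-configs):
+
+        >>> config = UltravoxConfig(
+        ...     audio_config={"model_type": "whisper"},
+        ...     text_config={"model_type": "llama"},
+        ...     stack_factor=8,
+        ... )
+        >>> config.model_type
+        'ultravox'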
+ """ + + model_type = "ultravox" + is_composition = False + + def __init__( + self, + audio_config: Optional[Dict[str, Any]] = None, + text_config: Optional[Dict[str, Any]] = None, + audio_model_id: Optional[str] = None, + text_model_id: Optional[str] = None, + ignore_index: int = -100, + audio_token_index: int = 32000, + hidden_size: int = 4096, + stack_factor: int = 8, + norm_init: float = 0.4, + projector_act: str = "swiglu", + text_model_lora_config: Optional[Dict[str, Any]] = None, + audio_model_lora_config: Optional[Dict[str, Any]] = None, + **kwargs, + ): + self.ignore_index = ignore_index + + self.audio_model_id = audio_model_id + self.text_model_id = text_model_id + self.audio_token_index = audio_token_index + + self.hidden_size = hidden_size + self.stack_factor = stack_factor + self.norm_init = norm_init + self.projector_act = projector_act + + if text_model_id is not None: + # Avoid circular import + from vllm.transformers_utils.config import get_config + + self.text_config = get_config(text_model_id, + trust_remote_code=False) + else: + text_config = text_config or {} + self.text_config = transformers.CONFIG_MAPPING[text_config.get( + "model_type", "llama")](**text_config) + + if audio_model_id is not None: + # Avoid circular import + from vllm.transformers_utils.config import get_config + + self.audio_config = get_config(audio_model_id, + trust_remote_code=False) + else: + audio_config = audio_config or {} + self.audio_config = transformers.CONFIG_MAPPING[audio_config.get( + "model_type", "whisper")](**audio_config) + + self.text_model_lora_config = text_model_lora_config or {} + self.audio_model_lora_config = audio_model_lora_config or {} + + self.vocab_size = self.text_config.vocab_size + + self.initializer_range = self.text_config.initializer_range + + super().__init__(**kwargs) From 5844017285acda7060ffc62e3dcedc0775eb4fe2 Mon Sep 17 00:00:00 2001 From: William Lin Date: Wed, 21 Aug 2024 15:52:40 -0700 Subject: [PATCH 08/16] [ci] [multi-step] narrow multi-step test dependency paths (#7760) --- .buildkite/test-pipeline.yaml | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index aa90145705f9d..4d1a997f6425f 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -312,12 +312,20 @@ steps: - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py -- label: Multi-step Tests (4 GPUs) # 10min +- label: Multi-step Tests (4 GPUs) # 21min working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: - - vllm/ - - tests/multi_step/test_correctness.py + - vllm/model_executor/layers/sampler.py + - vllm/sequence.py + - vllm/worker/worker_base.py + - vllm/worker/worker.py + - vllm/worker/multi_step_worker.py + - vllm/worker/model_runner_base.py + - vllm/worker/model_runner.py + - vllm/worker/multi_step_model_runner.py + - vllm/engine + - tests/multi_step commands: - pytest -v -s multi_step/test_correctness.py From 8678a69ab51956031e3bb70bdf1a781a8652e67d Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 21 Aug 2024 19:17:10 -0400 Subject: [PATCH 09/16] [Kernel] Expand MoE weight loading + Add Fused Marlin MoE Kernel (#7527) Co-authored-by: ElizaWszola --- CMakeLists.txt | 3 +- csrc/moe/marlin_moe_ops.cu | 1740 +++++++++++++++++ csrc/moe/marlin_moe_ops.h | 12 + csrc/moe/torch_bindings.cpp | 9 + tests/weight_loading/models.txt | 2 + vllm/_custom_ops.py | 14 + 
.../layers/fused_moe/__init__.py | 14 +- .../layers/fused_moe/fused_moe.py | 134 +- vllm/model_executor/layers/fused_moe/layer.py | 206 +- .../compressed_tensors/compressed_tensors.py | 5 + .../compressed_tensors_moe.py | 283 +++ .../model_executor/layers/quantization/fp8.py | 29 +- vllm/model_executor/model_loader/utils.py | 4 +- vllm/model_executor/models/jamba.py | 2 +- vllm/model_executor/models/mixtral.py | 1 + 15 files changed, 2374 insertions(+), 84 deletions(-) create mode 100644 csrc/moe/marlin_moe_ops.cu create mode 100644 csrc/moe/marlin_moe_ops.h create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 217dc70c4b24e..18e5109919104 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -286,7 +286,8 @@ define_gpu_extension_target( set(VLLM_MOE_EXT_SRC "csrc/moe/torch_bindings.cpp" - "csrc/moe/topk_softmax_kernels.cu") + "csrc/moe/topk_softmax_kernels.cu" + "csrc/moe/marlin_moe_ops.cu") define_gpu_extension_target( _moe_C diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu new file mode 100644 index 0000000000000..1e170e80d2f70 --- /dev/null +++ b/csrc/moe/marlin_moe_ops.cu @@ -0,0 +1,1740 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include +#include +#include + +#include + +template +inline std::string str(T x) { + return std::to_string(x); +} + +namespace marlin_moe { + +constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed +// for instance as inputs to tensor core operations. Consequently, all +// corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee +// this. +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { return elems[i]; } +}; + +using I4 = Vec; + +// Matrix fragments for tensor core instructions; their precise layout is +// documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type +using FragA = Vec; +using FragB = Vec; +using FragC = Vec; +using FragS = Vec; // quantization scales + +// Predicated asynchronous global->shared copy; used for inputs A where we apply +// predication to handle batchsizes that are not multiples of 16. 
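+// Each call moves one 16-byte int4; when the predicate is false the copy is
+// suppressed inside the PTX itself (via @p), so no divergent branch is needed.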
+__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +// Asynchronous global->shared copy +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +// Async copy fence. +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, + FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 +// values. We mostly follow the strategy in the link below, with some small +// changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. 
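+ // As packed fp16 pairs: SUB = 1032.0 = 1024 + 8 (strips the 0x6400 exponent
+ // bias and the zero point from the low nibbles), while MUL = 1/16 and
+ // ADD = -72 = -(1024/16 + 8) apply the same correction to the high nibbles,
+ // which sit 4 bits higher and are therefore 16x too large.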
+ const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +// Given 2 floats multiply by 2 scales (halves) +__device__ inline void scale_float(float* c, FragS& s) { + __half* s_ptr = reinterpret_cast<__half*>(&s); + c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); +} + +// Same as above, but for act_order (each K is multiplied individually) +__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2, + FragS& frag_s_3, FragS& frag_s_4, int i) { + __half2 s_val_1_2; + s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i]; + + __half2 s_val_3_4; + s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. 
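+// This is only needed for act_order checkpoints: the quantized B rows are
+// stored sorted by group id, so the K columns of A are shuffled to match
+// before the GEMM. Each thread block handles `block_rows` rows, with each
+// thread copying a strided subset of the K elements of a row.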
+__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) { + int start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) { + finish_row = size_m; + } + int cur_block_rows = finish_row - start_row; + + int row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) { + int iters = size_k / blockDim.x; + int rest = size_k % blockDim.x; + + int offset = row * row_stride; + + half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); + half* out_half = reinterpret_cast(out_int4_ptr + offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += blockDim.x; + } + + if (rest) { + if (threadIdx.x < rest) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } +} + +__global__ void compute_expert_offsets(int const* __restrict__ topk_ids, + int* __restrict__ expert_offsets, + int topk_length, int block_size) { + int expert_id = threadIdx.x; + int num_experts = blockDim.x; + + int occurrences = 0; + for (int i = 0; i < topk_length; ++i) { + occurrences += (topk_ids[i] == expert_id); + } + expert_offsets[expert_id + 1] = occurrences; + __syncthreads(); + + if (threadIdx.x == 0) { + int tot_offset = 0; + expert_offsets[0] = 0; + for (int i = 0; i < num_experts; ++i) { + tot_offset += ceildiv(expert_offsets[i + 1], block_size) * block_size; + expert_offsets[i + 1] = tot_offset; + } + } + __syncthreads(); +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__device__ inline void MarlinMoESingle( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? 
+ bool apply_weights, // apply weights to output + int current_m_block // current m block to start kernel computation from +) { + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * + ceildiv(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + sorted_ids += 16 * thread_m_blocks; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of 
shared write iterations for a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / 32; + constexpr int b_sh_stride = 32 * thread_n_blocks / 4; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); + constexpr int b_sh_wr_delta = threads; + constexpr int b_sh_rd_delta = threads; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = !has_act_order && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + constexpr int sorted_sh_stride = threads; + constexpr int sorted_gl_stride = threads; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = + b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x; + int b_sh_rd = threadIdx.x; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (group_blocks == -1 || group_blocks == 0) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + int shs_size; + if constexpr (has_act_order) + shs_size = sh_max_num_groups * s_sh_stride + threads; + else + shs_size = group_blocks > 0 ? 
stages * s_sh_stage : threads; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_s = sh_g_idx + (stages * g_idx_stage); + int* sh_sorted = (int*)(sh_s + shs_size); + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_sh_wr_delta * i + a_sh_wr; + int row = a_idx / a_gl_rd_delta_o; + if (row >= prob_m) { + a_sh_wr_pred[i] = false; + } else { + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + } + } + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + + // Zero accumulators. 
+ auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; + int row = a_idx / a_gl_stride; + int sorted_row = + replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; + int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; + if (sorted_row < tot_m * (replicate_input ? 1 : topk) && + new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) { + cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], + a_sh_wr_pred[i]); + } + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. 
+ cp_async_fence(); + }; + + // TODO we are currently hitting illegal memory accesses when fetching + // sorted_ids to shared data: fix this + auto fetch_sorted_ids_to_shared = [&]() { + const int mpt = ceildiv(prob_m, threads); + for (int i = 0; i < mpt; i++) { + if ((i * sorted_gl_stride) + threadIdx.x < prob_m) { + sh_sorted[(i * sorted_sh_stride) + threadIdx.x] = + sorted_ids[(i * sorted_gl_stride) + threadIdx.x]; + } + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + frag_b_quant[k % 2] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + int warp_id = threadIdx.x / 32; + int n_warps = + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + int th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + 
if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + + #pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant = frag_b_quant[k % 2][j]; + int b_quant_shift = b_quant >> 8; + + FragB frag_b0 = dequant(b_quant); + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); + } else { + if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k % 2][j], 0); + } + } + + FragB frag_b1 = dequant(b_quant_shift); + + // Apply scale to frag_b1 + if constexpr (has_act_order) { + scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); + + } else { + if constexpr (group_blocks != -1) { + scale(frag_b1, frag_s[k % 2][j], 1); + } + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride; + constexpr int red_sh_stride = b_sh_stride * 4 * 2; + constexpr int red_sh_delta = b_sh_stride; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + + (threadIdx.x % b_sh_stride); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. 
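+ // Each pass halves the number of participating warp groups: groups with
+ // red_idx in [i, 2*i) push their partial sums one level down, until
+ // red_idx == 0 accumulates the final per-column result.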
+ + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. 
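+ // (The cp_async_wait<0>() just below drains these copies immediately, so
+ // this is purely a code-generation trick rather than real overlap.)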
+ #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + int c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + int sorted_row = sorted_ids[c_idx / c_gl_stride]; + int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; + cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], + sorted_row < tot_m * topk && + (8 * (i / 2) + row < prob_m && + (i < (thread_m_blocks - 1) * 4 || + sorted_ids[8 * (i / 2) + row] < tot_m * topk))); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (8 * (i / 2) + row < prob_m && + (i < (thread_m_blocks - 1) * 4 || + sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + __half2float(reinterpret_cast<__half*>(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast<__half*>(&c)[j] = + __float2half(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); + } + int c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + int row = sorted_ids[c_idx / c_gl_stride]; + if (row < tot_m * topk) { + int new_idx = row * c_gl_stride + c_idx % c_gl_stride; + C[new_idx] = c; + } + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + half2 res = __halves2half2(__float2half(c0), __float2half(c1)); + + // For per-column quantization we finally apply the scale here + if constexpr (!has_act_order && group_blocks == -1) { + res = __hmul2(res, s[0]); + } + + ((half2*)sh)[idx] = res; + }; + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < ceildiv(16 * thread_m_blocks, threads / (2 * 
thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + int row = sorted_ids[c_gl_wr / c_gl_stride]; + if (row < tot_m * topk) { + int off = row * c_gl_stride + c_gl_wr % c_gl_stride; + if (!apply_weights) { + C[off] = sh[c_sh_rd]; + } else { + __half* ctrg = reinterpret_cast<__half*>(&C[off]); + __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); + for (int j = 0; j < 8; ++j) { + ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); + } + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + // TODO re-enable after fixing this function + // fetch_sorted_ids_to_shared(); + __syncthreads(); + + #pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + if constexpr (has_act_order) { + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. 
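+ // Epilogue of a slice: reduce partial results across the warps of this
+ // block and, when several blocks share the column slice, serially across
+ // blocks via the lock-protected global_reduce; only the last block of the
+ // slice writes the final output before moving on to the next column slice.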
+ if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (!has_act_order && group_blocks == -1) { + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1) { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } + start_pipes(); + } + } + } +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void MarlinMoE( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? 
+ bool apply_weights, // apply weights to output + int current_m_block, // current m block to start kernel computation from + int max_par // maximum parallelism +) { + int m_block_ctr = current_m_block; + + const int* sorted_ids_expert = + sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par; + int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx]; + if (tot_its == 0) { + return; + } + int tot_m_blocks = ceildiv(tot_its, 16); + int pad = 16 * tot_m_blocks - tot_its; + + if (m_block_ctr >= tot_m_blocks) { + return; + } + + int max_block = tot_m_blocks - m_block_ctr; + prob_m = tot_its - 16 * m_block_ctr; + + int par = 1; + if (max_block > 4) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * max_block - pad) / 64; + par = min((16 * max_block - pad) / 64, max_par); + prob_m = 64 * par; + m_block_ctr += 4 * (par - 1); + max_block = 4; + } + + if (max_block == 1) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else if (max_block == 2) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else if (max_block == 3) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } +} + +#else + +__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) { + // Marlin is not implemented yet for SM < 8.0 + assert(false); + return; +} + +__global__ void compute_expert_offsets(int const* __restrict__ topk_ids, + int* __restrict__ expert_offsets, + int topk_length, int block_size) { + // Marlin is not implemented yet for SM < 8.0 + assert(false); + return; +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void MarlinMoE( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction 
dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? + bool apply_weights, // apply weights to output + int current_m_block, // current m block to start kernel computation from + int max_par // maximum parallelism +) { + // Marlin is not implemented yet for SM < 8.0 + assert(false); + return; +} + +#endif + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +const int USER_THREADS = + 256; // Note: This is only used with user-provided thread_k/n +const int STAGES = 4; // 4 pipeline stages fit into shared memory +// const int SHARED_MEM = +// 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; + +#define __CALL_IF_MOE(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ + else if (thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ + num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + MarlinMoE, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + MarlinMoE \ + <<>>( \ + A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ + g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ + replicate_input, apply_weights, m_block, max_par); \ + } + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, // Default + {128, 64, 128}, // Reduce N 2X, same K + {64, 256, 256}, // Reduce K 2X, increase N 2X + {64, 128, 128}, // Reduce K 2X, same N +}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, // Default + {128, 128, 256}, // Reduce N 2X, increase K 2X + {64, 128, 128}, // Reduce N 2X, same K + {128, 64, 128}, // Reduce N 4X, increase K 2X +}; + +bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, + int prob_k) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || + th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // thread_k can be only 128 or 64 (because it must be less than groupsize + // which is 128) + if (th_config.thread_k != 128 && th_config.thread_k != 64) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + return true; +} + +thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { + return th_config; + } + } + + } else { + for (auto th_config : large_batch_thread_configs) { + if 
(is_valid_config(th_config, prob_m, prob_n, prob_k)) { + return th_config; + } + } + } + + return thread_config_t{-1, -1, -1}; +} + +#define CALL_IF_MOE(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + \ + __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) + +void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, + const void* sorted_ids, const void* topk_weights, + const void* topk_ids, const void* s, const void* g_idx, + const void* perm, void* a_tmp, void* expert_offsets, + int prob_m, int prob_n, int prob_k, void* workspace, + bool has_act_order, bool is_k_full, int num_groups, + int group_size, int num_experts, int topk, + int moe_block_size, int dev, cudaStream_t stream, + int thread_k, int thread_n, int sms, int max_par, + bool replicate_input, bool apply_weights) { + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, + ", ", prob_n, ", ", prob_k, "]"); + + if (sms == -1) { + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + } + + // Set thread config + thread_config_t th_config; + if (thread_k != -1 && thread_n != -1) { + // User-defined config + th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; + } else { + // Auto config + th_config = determine_thread_config(prob_m, prob_n, prob_k); + } + + TORCH_CHECK(is_valid_config(th_config, prob_m, prob_n, prob_k), + "Invalid thread config: thread_k = " + str(th_config.thread_k) + + ", thread_n = " + str(th_config.thread_n) + + ", num_threads = " + str(th_config.num_threads) + + " for MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " + + str(prob_n) + "]"); + + int num_threads = th_config.num_threads; + thread_k = th_config.thread_k; + thread_n = th_config.thread_n; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + int blocks = sms; + + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + + int group_blocks = 0; + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(group_size != -1); + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } else { + TORCH_CHECK(group_size == 0); 
+ group_blocks = 0; + } + + } else { + if (group_size == -1) { + group_blocks = -1; + } else { + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } + } + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + int tot_m = prob_m; + + const int* topk_ids_ptr = (const int*)topk_ids; + int* expert_offsets_ptr = (int*)expert_offsets; + compute_expert_offsets<<<1, num_experts, 0, stream>>>( + topk_ids_ptr, expert_offsets_ptr, tot_m * topk, moe_block_size); + + bool do_permute_a = has_act_order; + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by + // having a full K, we have full original groups) + if (is_k_full) { + has_act_order = false; + } + + for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + const int4* A_ptr = (const int4*)A; + int4* a_tmp_ptr = (int4*)a_tmp; + const int4* B_ptr = (const int4*)B + (prob_n * prob_k / 32) * expert_idx; + int4* C_ptr = (int4*)C; + const float* topk_weights_ptr = (const float*)topk_weights; + const int* sorted_ids_ptr = (const int*)sorted_ids; + const int4* s_ptr = + (const int4*)s + + (((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) * + prob_n / 8) * + expert_idx; + const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx; + const int* perm_ptr = (const int*)perm + prob_k * expert_idx; + int* locks = (int*)workspace; + + if (do_permute_a) { + // Permute A columns + int topk_rows = replicate_input ? tot_m : tot_m * topk; + int block_rows = ceildiv(topk_rows, blocks); + permute_cols_kernel<<>>( + A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows); + A_ptr = a_tmp_ptr; + } + + int max_m_blocks = ceildiv(tot_m, 16); + for (int m_block = 0; m_block < max_m_blocks; m_block += 16) { + // Define kernel configurations + + // make it max possible value + int thread_m_blocks = 4; + + if (false) { + } + CALL_IF_MOE(16, 4, 256) + CALL_IF_MOE(8, 8, 256) + CALL_IF_MOE(8, 4, 128) + CALL_IF_MOE(4, 8, 128) + else { + TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + + str(prob_n) + ", " + str(prob_k) + "]" + + ", has_act_order = " + str(has_act_order) + + ", num_groups = " + str(num_groups) + + ", group_size = " + str(group_size) + + ", thread_m_blocks = " + str(thread_m_blocks) + + ", thread_n_blocks = " + str(thread_n_blocks) + + ", thread_k_blocks = " + str(thread_k_blocks)); + } + } + } +} + +} // namespace marlin_moe + +torch::Tensor marlin_gemm_moe( + const torch::Tensor& a, const torch::Tensor& b_q_weights, + const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, + const torch::Tensor& topk_ids, const torch::Tensor& b_scales, + const torch::Tensor& g_idx, const torch::Tensor& perm, + torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, + bool replicate_input, bool apply_weights) { + int max_par = 4; + + int dev = a.get_device(); + + auto options_dtype = + torch::TensorOptions().dtype(a.dtype()).device(a.device()); + auto options_int = + torch::TensorOptions().dtype(torch::kInt).device(a.device()); + torch::Tensor c = torch::zeros({size_m, topk, size_n}, options_dtype); + torch::Tensor a_tmp = + replicate_input ? 
torch::zeros({size_m, size_k}, options_dtype) + : torch::zeros({size_m, topk, size_k}, options_dtype); + torch::Tensor expert_offsets = torch::empty({num_experts + 1}, options_int); + + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel (can usually be left as auto -1) + int sms = -1; + + // Detect groupsize and act_order + int num_groups = -1; + int group_size = -1; + bool has_act_order = g_idx.size(1) != 0; + + int b_rank = b_scales.sizes().size(); + TORCH_CHECK(b_rank == 3, "b_scales rank = ", b_rank, " is not 3"); + TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2), + " is not size_n = ", size_n); + num_groups = b_scales.size(1); + + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); + TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by num_groups = ", num_groups); + group_size = size_k / num_groups; + } else { + group_size = 0; + } + + } else { + if (num_groups > 1) { + TORCH_CHECK( + size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by b_scales.size(0) = ", b_scales.size(0)); + group_size = size_k / num_groups; + } else { + group_size = -1; + } + } + + marlin_moe::marlin_mm_moe_f16i4( + a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), + topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), + expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), + has_act_order, is_k_full, num_groups, group_size, num_experts, topk, + moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, + thread_n, sms, max_par, replicate_input, apply_weights); + return c; +} \ No newline at end of file diff --git a/csrc/moe/marlin_moe_ops.h b/csrc/moe/marlin_moe_ops.h new file mode 100644 index 0000000000000..01ba8ff69850d --- /dev/null +++ b/csrc/moe/marlin_moe_ops.h @@ -0,0 +1,12 @@ +#pragma once + +#include + +torch::Tensor marlin_gemm_moe( + const torch::Tensor& a, const torch::Tensor& b_q_weights, + const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, + const torch::Tensor& topk_ids, const torch::Tensor& b_scales, + const torch::Tensor& g_idx, const torch::Tensor& perm, + torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, + bool replicate_input, bool apply_weights); \ No newline at end of file diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 86e42af44df15..cda1405b4e4f1 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -1,5 +1,6 @@ #include "core/registration.h" #include "moe_ops.h" +#include "marlin_moe_ops.h" TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // Apply topk softmax to the gating outputs. @@ -7,6 +8,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " "token_expert_indices, Tensor gating_output) -> ()"); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); + m.def( + "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " + "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " + "g_idx, Tensor! perm, Tensor! 
workspace, int size_m, int size_n, int " + "size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, " + "bool replicate_input, bool apply_weights) -> Tensor"); + + m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt index 064dbb1feee83..c074b4b44c768 100644 --- a/tests/weight_loading/models.txt +++ b/tests/weight_loading/models.txt @@ -13,5 +13,7 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main +compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main +compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main awq, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index b89a90ef0f70c..ae90af563c0cf 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -300,6 +300,20 @@ def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int, return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) +def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + num_experts = b_q_weight.shape[0] + assert size_k % 16 == 0 + output = torch.empty((num_experts, size_k // 16, size_n * 2), + device=b_q_weight.device, + dtype=b_q_weight.dtype) + for e in range(num_experts): + output[e] = torch.ops._C.gptq_marlin_repack(b_q_weight[e], perm[e], + size_k, size_n, num_bits) + return output + + def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 3e0767c7d2665..fd6f41b90042e 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,19 +1,17 @@ -from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, - FusedMoEMethodBase) +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.triton_utils import HAS_TRITON -__all__ = [ - "FusedMoE", - "FusedMoEMethodBase", -] +__all__ = ["FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported"] if HAS_TRITON: from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_experts, fused_moe, fused_topk, get_config_file_name, - grouped_topk) + fused_experts, fused_marlin_moe, fused_moe, fused_topk, + get_config_file_name, grouped_topk) __all__ += [ + "fused_marlin_moe", "fused_moe", "fused_topk", "fused_experts", diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index bcf25d2631042..d2b152320e11e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -323,21 +323,16 @@ def get_moe_configs(E: int, N: int, return None -def get_default_config( - M: int, - E: int, - N: int, - K: int, - topk: int, - dtype: Optional[str], -) -> Dict[str, int]: +def get_default_config(M: int, E: int, N: int, K: int, topk: int, + dtype: Optional[str], + is_marlin: bool) -> Dict[str, int]: config = { 'BLOCK_SIZE_M': 64, 
'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8 } - if M <= E: + if M <= E or (is_marlin and M <= 32): config = { 'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, @@ -347,14 +342,14 @@ def get_default_config( return config -def try_get_optimal_moe_config( - w1_shape: Tuple[int, ...], - w2_shape: Tuple[int, ...], - top_k: int, - dtype: Optional[str], - M: int, - override_config: Optional[Dict[str, Any]] = None, -): +def try_get_optimal_moe_config(w1_shape: Tuple[int, ...], + w2_shape: Tuple[int, ...], + top_k: int, + dtype: Optional[str], + M: int, + override_config: Optional[Dict[str, + Any]] = None, + is_marlin: bool = False): if override_config: config = override_config else: @@ -368,7 +363,8 @@ def try_get_optimal_moe_config( config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Else use the default config - config = get_default_config(M, E, N, w1_shape[2], top_k, dtype) + config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, + is_marlin) return config @@ -441,6 +437,108 @@ def grouped_topk(hidden_states: torch.Tensor, return topk_weights, topk_ids +def fused_marlin_moe(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + g_idx1: torch.Tensor, + g_idx2: torch.Tensor, + rand_perm1: torch.Tensor, + rand_perm2: torch.Tensor, + topk: int, + renormalize: bool, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. 
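+    # Shape conventions assumed by the checks below (int4-packed weights):
+    #   hidden_states: [M, K], gating_output: [M, E]
+    #   w1 (gate/up):  K == w1.shape[1] * 16, E == w1.shape[0]
+    #   w2 (down):     N == w2.shape[1] * 16, w2.shape[2] == 2 * K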
+ assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + assert hidden_states.shape[ + 1] == w1.shape[1] * 16, "Hidden size mismatch w1" + assert hidden_states.shape[ + 1] == w2.shape[2] // 2, "Hidden size mismatch w2" + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + + #TODO fp8 is not implemented yet + assert not use_fp8 + + M, K = hidden_states.shape + E = w1.shape[0] + N = w2.shape[1] * 16 + + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + + get_config_func = functools.partial(try_get_optimal_moe_config, + w1.shape, + w2.shape, + topk_ids.shape[1], + "float8" if use_fp8 else None, + override_config=override_config, + is_marlin=True) + config = get_config_func(M) + + block_size_m = config['BLOCK_SIZE_M'] + + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) + + max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda", + requires_grad=False) + + intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype) + + intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( + hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale, + g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk, + block_size_m, True, False) + + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) + + intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( + intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids, + w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk, + block_size_m, False, True) + + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1) + + def get_config_dtype_str(dtype: torch.dtype, use_int8_w8a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False): diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 4e29ab701b937..3a77bf30131f9 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,4 +1,5 @@ from abc import abstractmethod +from enum import Enum from typing import List, Optional, Tuple import torch @@ -15,6 +16,12 @@ logger = init_logger(__name__) +class FusedMoeWeightScaleSupported(Enum): + TENSOR = "tensor" + CHANNEL = "channel" + GROUP = "group" + + class FusedMoEMethodBase(QuantizeMethodBase): @abstractmethod @@ -199,55 +206,182 @@ def __init__( params_dtype=params_dtype, weight_loader=self.weight_loader) + def _load_per_tensor_weight_scale(self, shard_id: str, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + expert_id: int): + param_data = param.data + # for per tensor weight quantization + if shard_id in ("w1", "w3"): + # We have to keep the weight scales of w1 and w3 because + # we need to re-quantize w1/w3 weights after weight loading. 
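+            # Per-tensor scales are stored as two slots per expert:
+            # index 0 holds the w1 (gate) scale, index 1 the w3 (up) scale.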
+ idx = 0 if shard_id == "w1" else 1 + param_data[expert_id][idx] = loaded_weight + # If we are in the row parallel case (down_proj) + elif shard_id == "w2": + param_data[expert_id] = loaded_weight + + def _load_model_weight_or_group_weight_scale(self, shard_dim: int, + expert_data: torch.Tensor, + shard_id: str, + loaded_weight: torch.tensor, + tp_rank: int): + # Load grouped weight scales for group quantization + # or model weights + if shard_id == "w2": + self._load_w2(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + elif shard_id in ("w1", "w3"): + self._load_w13(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + + def _load_per_channel_weight_scale(self, expert_data: torch.Tensor, + shard_dim: int, shard_id: str, + loaded_weight: torch.tensor, + tp_rank: int): + # for per channel weight quantization + if shard_id == "w2": + expert_data.copy_(loaded_weight) + elif shard_id in ("w1", "w3"): + self._load_w13(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + + def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, + shard_id: str, loaded_weight: torch.tensor, tp_rank: int): + + # Index the loaded weight for tp sharding. + # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim + shard_size = expert_data.shape[shard_dim] // 2 + loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, + shard_size) + # Narrow parameter and load. + # w1, gate_proj: Load into first logical weight of w13. + if shard_id == "w1": + expert_data = expert_data.narrow(shard_dim, 0, shard_size) + # w3, up_proj: Load into second logical weight of w13. + else: + assert shard_id == "w3" + expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) + expert_data.copy_(loaded_weight) + + def _load_w2(self, expert_data: torch.Tensor, shard_dim: int, + shard_id: str, loaded_weight: torch.tensor, tp_rank: int): + + # Index the loaded weight for tp sharding. + # down_proj: "RowParallel" so tp sharding on input_dim + # Narrow parameter and load. + shard_size = expert_data.shape[shard_dim] + loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, + shard_size) + # w2, down_proj: Load into only logical weight of w2. + expert_data.copy_(loaded_weight) + + def _load_single_value(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, expert_id: int): + param_data = param.data + + # Input scales can be loaded directly and should be equal. + param_data[expert_id] = loaded_weight + def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, shard_id: str, expert_id: int) -> None: + if shard_id not in ("w1", "w2", "w3"): raise ValueError(f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}.") - # Special case for fp8 scales. - if getattr(param, "is_fp8_scale", False): - self._load_fp8_scale(param.data, loaded_weight, weight_name, - shard_id, expert_id) - return + WEIGHT_SCALE_SUPPORTED = [ + e.value for e in FusedMoeWeightScaleSupported + ] + # Fetch the dim to shard the parameter/loaded weight + # based on the shard id. This will be whatever + # dimension intermediate_size is used. 
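+        # w1/w3 (gate/up, column-parallel) shard dim 0 of the expert tensor,
+        # while w2 (down, row-parallel) shards dim 1.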
+ SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0} expert_data = param.data[expert_id] tp_rank = get_tensor_model_parallel_rank() - # If transposed, weight is saved as [input_dim, output_dim] - # Otherwise, weight is saved as [output_dim, input_dim] - # Default is not transposed/input dim is dim 1 - input_dim = getattr(param, "input_dim", 1) - output_dim = getattr(param, "output_dim", 0) + # is_transposed: whether or not the parameter is transposed on disk + # If transposed, the loaded weight will be transposed and the dim + # to shard the loaded weight will be flipped. + is_transposed = getattr(param, "is_transposed", False) + shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] + if is_transposed: + loaded_weight = loaded_weight.t().contiguous() + shard_dim = ~shard_dim + + # Case weight_scales + if "weight_scale" in weight_name: + # load the weight scaling based on the quantization scheme + # supported weight scales can be found in + # FusedMoeWeightScaleSupported + # TODO @dsikka: once hardened, refactor to use vLLM Parameters + # specific to each case + quant_method = getattr(param, "quant_method", None) + if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value: + self._load_per_channel_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + elif quant_method == FusedMoeWeightScaleSupported.GROUP.value: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value: + self._load_per_tensor_weight_scale(shard_id=shard_id, + param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + else: + raise ValueError( + f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}") + return - # Index the loaded weight for tp sharding. - # down_proj: "RowParallel" so tp sharding on input_dim - if shard_id == "w2": - shard_dim = input_dim - shard_size = expert_data.shape[shard_dim] - # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim - elif shard_id in ("w1", "w3"): - shard_dim = output_dim - shard_size = expert_data.shape[output_dim] // 2 - offset = shard_size * tp_rank - loaded_weight = loaded_weight.narrow(shard_dim, offset, shard_size) + if "weight_shape" in weight_name: + self._load_single_value(param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + return - # Narrow parameter and load. - # w1, gate_proj: Load into first logical weight of w13. - if shard_id == "w1": - expert_data = expert_data.narrow(shard_dim, 0, shard_size) - expert_data.copy_(loaded_weight) - # w3, up_proj: Load into second logical weight of w13. - elif shard_id == "w3": - expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) - expert_data.copy_(loaded_weight) - # w2, down_proj: Load into only logical weight of w2. - elif shard_id == "w2": - expert_data.copy_(loaded_weight) - else: - raise ValueError( - f"Expected shard_id w1,w2 or w3 but got {shard_id}") + # Case input scale + if "input_scale" in weight_name: + # Note: input_scale loading is only supported for fp8 + if param.data[expert_id] != 1 and (param.data[expert_id] - + loaded_weight).abs() > 1e-5: + raise ValueError( + "input_scales of w1 and w3 of a layer " + f"must be equal. But got {param.data[expert_id]} " + f"vs. 
{loaded_weight}") + + self._load_single_value(param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + return + + # Case model weights + if "weight" in weight_name: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + return @staticmethod def select_experts(hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index ae75781927381..759dd9c0dd4ef 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -3,9 +3,12 @@ import torch from pydantic import BaseModel +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501 + CompressedTensorsMoEMethod) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensorsScheme, CompressedTensorsUnquantized, @@ -64,6 +67,8 @@ def get_quant_method( return CompressedTensorsLinearMethod(self) if isinstance(layer, Attention): return CompressedTensorsKVCacheMethod(self) + if isinstance(layer, FusedMoE): + return CompressedTensorsMoEMethod(self) return None @classmethod diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py new file mode 100644 index 0000000000000..0e0ab9ce9169f --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -0,0 +1,283 @@ +import enum +from enum import Enum +from typing import List, Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + WNA16_SUPPORTED_BITS) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + CompressionFormat) +from vllm.model_executor.utils import set_weight_attrs + + +class GPTQMarlinState(Enum): + REPACK = enum.auto() + READY = enum.auto() + + +__all__ = ["CompressedTensorsMoEMethod"] + + +class CompressedTensorsMoEMethod(FusedMoEMethodBase): + + def __init__( + self, + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + ): + self.quant_config = quant_config + # TODO: @dsikka: refactor this to use schemes as other kernels + # are supported + check if the layer is being ignored. 
+ config = self.quant_config.target_scheme_map["Linear"].get("weights") + self.num_bits = config.num_bits + self.packed_factor = 32 // config.num_bits + self.strategy = config.strategy.value + self.group_size = config.group_size + assert config.symmetric, ( + "Only symmetric quantization is supported for MoE") + + if not (self.quant_config.quant_format + == CompressionFormat.pack_quantized.value + and self.num_bits in WNA16_SUPPORTED_BITS): + raise ValueError("For Fused MoE layers, only ", + f"{CompressionFormat.pack_quantized.value} ", + "is supported for the following bits: ", + f"{WNA16_SUPPORTED_BITS}") + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + # Will transpose the loaded weight along the + # intermediate and hidden dim sizes. Will + # shard for TP along the transposed dims + extra_weight_attrs.update({ + "is_transposed": True, + "quant_method": self.strategy + }) + w13_weight = torch.nn.Parameter(torch.empty(num_experts, + hidden_size // + self.packed_factor, + 2 * intermediate_size, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty(num_experts, + intermediate_size // + self.packed_factor, + hidden_size, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_weight_packed", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + if self.strategy == "channel": + num_groups_w2 = num_groups_w13 = 1 + self.group_size = -1 + else: + num_groups_w2 = intermediate_size // self.group_size + num_groups_w13 = hidden_size // self.group_size + + w13_scale = torch.nn.Parameter(torch.ones(num_experts, + num_groups_w13, + 2 * intermediate_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_scale) + set_weight_attrs(w13_scale, extra_weight_attrs) + + w2_scale = torch.nn.Parameter(torch.ones(num_experts, + num_groups_w2, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_scale) + set_weight_attrs(w2_scale, extra_weight_attrs) + + w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), + requires_grad=False) + layer.register_parameter("w2_weight_shape", w2_weight_shape) + set_weight_attrs(w2_weight_shape, extra_weight_attrs) + w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), + requires_grad=False) + + layer.register_parameter("w13_weight_shape", w13_weight_shape) + set_weight_attrs(w13_weight_shape, extra_weight_attrs) + + w13_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx", w13_g_idx) + set_weight_attrs(w13_g_idx, extra_weight_attrs) + + w2_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx", w2_g_idx) + set_weight_attrs(w2_g_idx, extra_weight_attrs) + + w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) + + w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, 
+ dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + + layer.a13_scale = None + layer.a2_scale = None + layer.marlin_state = GPTQMarlinState.REPACK + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + def replace_tensor(name, new_t): + # It is important to use resize_() here since it ensures + # the same buffer is reused + getattr(layer, name).resize_(new_t.shape) + getattr(layer, name).copy_(new_t) + del new_t + + def get_scale_perms(num_bits: int): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, + group_size: int, num_bits: int): + scale_perm, scale_perm_single = get_scale_perms(num_bits) + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, + scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + return s + + def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, + size_n: int, group_size: int, + num_bits: int): + num_experts = s.shape[0] + output = torch.empty((num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype) + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, + group_size, num_bits) + return output + + size_k2 = layer.w2_weight_packed.shape[2] + size_k13 = layer.w13_weight_packed.shape[2] + + num_experts = layer.w13_g_idx.shape[0] + device = layer.w13_g_idx.device + layer.w13_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w2_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + + marlin_w13_qweight = ops.gptq_marlin_moe_repack( + layer.w13_weight_packed, + layer.w13_g_idx_sort_indices, + layer.w13_weight_packed.shape[1] * self.packed_factor, + layer.w13_weight_packed.shape[2], + self.num_bits, + ) + replace_tensor("w13_weight_packed", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( + layer.w2_weight_packed, + layer.w2_g_idx_sort_indices, + layer.w2_weight_packed.shape[1] * self.packed_factor, + layer.w2_weight_packed.shape[2], + self.num_bits, + ) + replace_tensor("w2_weight_packed", marlin_w2_qweight) + # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( + layer.w13_weight_scale, + size_k13, + layer.w13_weight_scale.shape[2], + self.group_size, + self.num_bits, + ) + replace_tensor("w13_weight_scale", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( + layer.w2_weight_scale, + layer.w2_weight_scale.shape[1] * self.packed_factor, + size_k2, + self.group_size, + self.num_bits, + ) + replace_tensor("w2_weight_scale", marlin_w2_scales) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: 
bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None) -> torch.Tensor: + + from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_marlin_moe) + + return fused_marlin_moe(x, + layer.w13_weight_packed, + layer.w2_weight_packed, + router_logits, + layer.w13_g_idx, + layer.w2_g_idx, + layer.w13_g_idx_sort_indices, + layer.w2_g_idx_sort_indices, + top_k, + renormalize=renormalize, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index fd7682a1c0f51..7f45a20bd9dd9 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -7,7 +7,8 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, + FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization.base_config import ( @@ -318,19 +319,16 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, dtype=torch.float32), requires_grad=False) layer.register_parameter("w2_weight_scale", w2_weight_scale) - + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) # If loading fp8 checkpoint, pass the weight loaders. # If loading an fp16 checkpoint, do not (we will quantize in # process_weights_after_loading() if self.quant_config.is_checkpoint_fp8_serialized: - set_weight_attrs(w13_weight_scale, { - "is_fp8_scale": True, - **extra_weight_attrs - }) - set_weight_attrs(w2_weight_scale, { - "is_fp8_scale": True, - **extra_weight_attrs - }) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) # INPUT_SCALES if self.quant_config.activation_scheme == "static": @@ -343,19 +341,14 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, num_experts, dtype=torch.float32), requires_grad=False) layer.register_parameter("w13_input_scale", w13_input_scale) - set_weight_attrs(w13_input_scale, { - "is_fp8_scale": True, - **extra_weight_attrs - }) + set_weight_attrs(w13_input_scale, extra_weight_attrs) w2_input_scale = torch.nn.Parameter(torch.ones( num_experts, dtype=torch.float32), requires_grad=False) layer.register_parameter("w2_input_scale", w2_input_scale) - set_weight_attrs(w2_input_scale, { - "is_fp8_scale": True, - **extra_weight_attrs - }) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + else: layer.w13_input_scale = None layer.w2_input_scale = None diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 331b859d2adec..4bb943ab3afe4 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -23,11 +23,11 @@ def get_model_architecture( architectures = getattr(model_config.hf_config, "architectures", []) # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. 
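+    # Quantization methods listed here have a native FusedMoE path, so the
+    # standard MixtralForCausalLM is kept; other quantized Mixtral
+    # checkpoints still fall back to QuantMixtralForCausalLM.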
+ mixtral_supported = ["fp8", "compressed-tensors"] if (model_config.quantization is not None - and model_config.quantization != "fp8" + and model_config.quantization not in mixtral_supported and "MixtralForCausalLM" in architectures): architectures = ["QuantMixtralForCausalLM"] - return ModelRegistry.resolve_model_cls(architectures) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index b82eb14fb5f23..caeda4e42d8a0 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -920,7 +920,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = param.weight_loader weight_loader(param, loaded_weight, - weight_name, + name, shard_id=shard_id, expert_id=expert_id) break diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 34f581ac78582..413783ba4b259 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -73,6 +73,7 @@ def __init__(self, self.hidden_size = hidden_size # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear(hidden_size, num_experts, bias=False, From 7eebe8ccaa8bb9c37d59d00cbedcd5e67308acfe Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 21 Aug 2024 16:25:34 -0700 Subject: [PATCH 10/16] [distributed][misc] error on same VLLM_HOST_IP setting (#7756) --- vllm/envs.py | 5 ++++- vllm/executor/ray_gpu_executor.py | 13 +++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/vllm/envs.py b/vllm/envs.py index e4cf6a028ac18..4f7a7ad7821d5 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -137,7 +137,10 @@ def get_default_config_root(): os.path.join(get_default_cache_root(), "vllm"), )), - # used in distributed environment to determine the master address + # used in distributed environment to determine the ip address + # of the current node, when the node has multiple network interfaces. + # If you are using multi-node inference, you should set this differently + # on each node. 'VLLM_HOST_IP': lambda: os.getenv('VLLM_HOST_IP', "") or os.getenv("HOST_IP", ""), diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index aec6998d4488d..760c06cb6c06f 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -218,6 +218,19 @@ def sort_by_driver_then_worker_ip(worker): for node_id, gpu_ids in node_gpus.items(): node_gpus[node_id] = sorted(gpu_ids) + all_ips = set(worker_ips + [driver_ip]) + n_ips = len(all_ips) + n_nodes = len(node_workers) + + if n_nodes != n_ips: + raise RuntimeError( + f"Every node should have a unique IP address. Got {n_nodes}" + f" nodes with node ids {list(node_workers.keys())} and " + f"{n_ips} unique IP addresses {all_ips}. Please check your" + " network configuration. If you set `VLLM_HOST_IP` or " + "`HOST_IP` environment variable, make sure it is unique for" + " each node.") + VLLM_INSTANCE_ID = get_vllm_instance_id() # Set environment variables for the driver and workers. 
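With the uniqueness check above, multi-node deployments must export a distinct
VLLM_HOST_IP on every node. A minimal sketch of how a launch script could derive
a node-local address before starting vLLM (the interface-probing helper below is
only an illustration under that assumption, not part of this patch):

import os
import socket

def node_local_ip() -> str:
    # Open a UDP socket toward a public address purely to learn which local
    # interface the OS would route through; no packets are actually sent.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(("8.8.8.8", 80))
        return s.getsockname()[0]

os.environ["VLLM_HOST_IP"] = node_local_ip()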
From 9984605412de1171a72d955cfcb954725edd4d6f Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Wed, 21 Aug 2024 19:47:36 -0400 Subject: [PATCH 11/16] [AMD][CI/Build] Disambiguation of the function call for ROCm 6.2 headers compatibility (#7477) Co-authored-by: Charlie Fu --- csrc/attention/attention_utils.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh index cdcee42748998..826b0edffae67 100644 --- a/csrc/attention/attention_utils.cuh +++ b/csrc/attention/attention_utils.cuh @@ -34,7 +34,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { A_vec qk_vec = mul(q[0], k[0]); #pragma unroll for (int ii = 1; ii < N; ++ii) { - qk_vec = fma(q[ii], k[ii], qk_vec); + qk_vec = vllm::fma(q[ii], k[ii], qk_vec); } // Finalize the reduction across lanes. From 7937009a7e82c3c4c9c7f48d11142bee5aac4a30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Wed, 21 Aug 2024 20:18:00 -0400 Subject: [PATCH 12/16] [Kernel] Replaced `blockReduce[...]` functions with `cub::BlockReduce` (#7233) Co-authored-by: Michael Goin --- .../configs/Meta-Llama-3-8B-QQQ.yaml | 4 +- benchmarks/kernels/benchmark_layernorm.py | 89 +++++++++++++++ benchmarks/kernels/benchmark_quant.py | 103 ++++++++++++++++++ csrc/layernorm_kernels.cu | 33 +++--- .../compressed_tensors/int8_quant_kernels.cu | 14 ++- csrc/quantization/fp8/common.cu | 13 ++- csrc/reduction_utils.cuh | 95 ---------------- .../basic_correctness/test_chunked_prefill.py | 2 +- 8 files changed, 237 insertions(+), 116 deletions(-) create mode 100644 benchmarks/kernels/benchmark_layernorm.py create mode 100644 benchmarks/kernels/benchmark_quant.py delete mode 100644 csrc/reduction_utils.cuh diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml index c457468902c98..0424586598391 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml @@ -4,8 +4,8 @@ tasks: - name: "gsm8k" metrics: - name: "exact_match,strict-match" - value: 0.409 + value: 0.419 - name: "exact_match,flexible-extract" - value: 0.406 + value: 0.416 limit: 1000 num_fewshot: 5 diff --git a/benchmarks/kernels/benchmark_layernorm.py b/benchmarks/kernels/benchmark_layernorm.py new file mode 100644 index 0000000000000..4947fda02e1cc --- /dev/null +++ b/benchmarks/kernels/benchmark_layernorm.py @@ -0,0 +1,89 @@ +import random +import time + +import torch + +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser + + +@torch.inference_mode() +def main(num_tokens: int, + hidden_size: int, + add_residual: bool, + dtype: torch.dtype, + seed: int = 0, + do_profile: bool = False, + num_warmup_iters: int = 5, + num_iters: int = 100) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device("cuda") + + layer = RMSNorm(hidden_size).to(dtype=dtype) + layer.weight.data.normal_(mean=1.0, std=0.1) + scale = 1 / (2 * hidden_size) + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + x *= scale + residual = torch.randn_like(x) * scale if add_residual else None + + def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: + torch.cuda.synchronize() + if profile: + torch.cuda.cudart().cudaProfilerStart() + 
start_time = time.perf_counter() + + for _ in range(num_iters): + layer(x, residual) + torch.cuda.synchronize() + + end_time = time.perf_counter() + if profile: + torch.cuda.cudart().cudaProfilerStart() + return (end_time - start_time) / num_iters + + # Warmup. + print("Warming up...") + run_benchmark = run_cuda_benchmark + run_benchmark(num_iters=num_warmup_iters, profile=False) + + # Benchmark. + if do_profile: + latency = run_benchmark(num_iters=1, profile=True) + else: + latency = run_benchmark(num_iters=num_iters, profile=False) + print(f"Kernel running time: {latency * 1000000:.3f} us") + + +if __name__ == '__main__': + parser = FlexibleArgumentParser( + description="Benchmark the layernorm kernel.") + parser.add_argument("--num-tokens", type=int, default=4096) + parser.add_argument("--hidden-size", type=int, default=8192) + parser.add_argument("--add-residual", action="store_true") + parser.add_argument("--dtype", + type=str, + choices=["half", "bfloat16", "float"], + default="half") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--profile", action="store_true") + parser.add_argument("--num-warmup-iters", type=int, default=5) + parser.add_argument("--num-iters", + type=int, + default=100, + help="Number of benchmark iterations. " + "If --profile is set, this number is ignored") + + args = parser.parse_args() + print(args) + + main(num_tokens=args.num_tokens, + hidden_size=args.hidden_size, + add_residual=args.add_residual, + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], + seed=args.seed, + do_profile=args.profile, + num_warmup_iters=args.num_warmup_iters, + num_iters=args.num_iters) diff --git a/benchmarks/kernels/benchmark_quant.py b/benchmarks/kernels/benchmark_quant.py new file mode 100644 index 0000000000000..4c1a7b26213a5 --- /dev/null +++ b/benchmarks/kernels/benchmark_quant.py @@ -0,0 +1,103 @@ +import random +import time + +import torch + +from vllm import _custom_ops as ops +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser + + +@torch.inference_mode() +def main(num_tokens: int, + hidden_size: int, + static_scale: bool, + quant_dtype: torch.dtype, + dtype: torch.dtype, + seed: int = 0, + do_profile: bool = False, + num_warmup_iters: int = 5, + num_iters: int = 100) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device("cuda") + + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + scale = torch.randn(1, 1, dtype=torch.float32) if static_scale else None + + def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: + torch.cuda.synchronize() + if profile: + torch.cuda.cudart().cudaProfilerStart() + start_time = time.perf_counter() + + for _ in range(num_iters): + if quant_dtype == torch.int8: + ops.scaled_int8_quant(x, scale) + else: + ops.scaled_fp8_quant(x, scale) + torch.cuda.synchronize() + + end_time = time.perf_counter() + if profile: + torch.cuda.cudart().cudaProfilerStart() + return (end_time - start_time) / num_iters + + # Warmup. + print("Warming up...") + run_benchmark = run_cuda_benchmark + run_benchmark(num_iters=num_warmup_iters, profile=False) + + # Benchmark. 
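+    # `latency` is the mean seconds per iteration returned by
+    # run_cuda_benchmark; the printout below reports it in microseconds.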
+ if do_profile: + latency = run_benchmark(num_iters=1, profile=True) + else: + latency = run_benchmark(num_iters=num_iters, profile=False) + print(f"Kernel running time: {latency * 1000000:.3f} us") + + +if __name__ == '__main__': + + def to_torch_dtype(dt): + if dt == "int8": + return torch.int8 + if dt == "fp8": + return torch.float8_e4m3fn + raise ValueError(f"Unsupported dtype: {dt}") + + parser = FlexibleArgumentParser( + description="Benchmark the quantization (fp8 or int8) kernel.") + parser.add_argument("--num-tokens", type=int, default=4096) + parser.add_argument("--hidden-size", type=int, default=8192) + parser.add_argument("--static-scale", action="store_true") + parser.add_argument("--quant-dtype", + type=str, + choices=["fp8", "int8"], + default="int8") + parser.add_argument("--dtype", + type=str, + choices=["half", "bfloat16", "float"], + default="half") + + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--profile", action="store_true") + parser.add_argument("--num-warmup-iters", type=int, default=5) + parser.add_argument("--num-iters", + type=int, + default=100, + help="Number of benchmark iterations. " + "If --profile is set, this number is ignored") + + args = parser.parse_args() + print(args) + + main(num_tokens=args.num_tokens, + hidden_size=args.hidden_size, + static_scale=args.static_scale, + quant_dtype=to_torch_dtype(args.quant_dtype), + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], + seed=args.seed, + do_profile=args.profile, + num_warmup_iters=args.num_warmup_iters, + num_iters=args.num_iters) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index ca1c04bd880d9..7a7a25d2173d2 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -3,13 +3,16 @@ #include #include "dispatch_utils.h" -#include "reduction_utils.cuh" #ifndef USE_ROCM #include #include + #include + #include #else #include #include + #include + #include using __nv_bfloat16 = __hip_bfloat16; using __nv_bfloat162 = __hip_bfloat162; @@ -31,7 +34,11 @@ __global__ void rms_norm_kernel( const float x = (float)input[blockIdx.x * hidden_size + idx]; variance += x * x; } - variance = blockReduceSum(variance); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } @@ -228,12 +235,11 @@ fused_add_rms_norm_kernel( variance += temp.sum_squares(); residual_v[id] = temp; } - /* Keep the following if-else block in sync with the - calculation of max_block_size in fused_add_rms_norm */ - if (num_tokens < 256) { - variance = blockReduceSum(variance); - } else - variance = blockReduceSum(variance); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } @@ -268,12 +274,11 @@ fused_add_rms_norm_kernel( variance += x * x; residual[blockIdx.x * hidden_size + idx] = z; } - /* Keep the following if-else block in sync with the - calculation of max_block_size in fused_add_rms_norm */ - if (num_tokens < 256) { - variance = blockReduceSum(variance); - } else - variance = blockReduceSum(variance); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, 
cub::Sum{}, blockDim.x); + if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index aa9511daa2772..616fc149760e5 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -3,7 +3,14 @@ #include #include "../../dispatch_utils.h" -#include "../../reduction_utils.cuh" + +#ifndef USE_ROCM + #include + #include +#else + #include + #include +#endif static inline __device__ int8_t float_to_int8_rn(float x) { #ifdef USE_ROCM @@ -55,7 +62,10 @@ __global__ void dynamic_scaled_int8_quant_kernel( absmax_val = val > absmax_val ? val : absmax_val; } - float const block_absmax_val_maybe = blockReduceMax(absmax_val); + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStorage; + float const block_absmax_val_maybe = + BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x); __shared__ float block_absmax_val; if (tid == 0) { block_absmax_val = block_absmax_val_maybe; diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index 3f77c76ae7ec4..7e23f92257769 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -7,7 +7,13 @@ #include "cuda_compat.h" #include "dispatch_utils.h" -#include "../../reduction_utils.cuh" +#ifndef USE_ROCM + #include + #include +#else + #include + #include +#endif #ifndef USE_ROCM using FP8_TYPE = c10::Float8_e4m3fn; @@ -215,7 +221,10 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel( } } - float const block_absmax_val_maybe = blockReduceMax(absmax_val); + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStorage; + float const block_absmax_val_maybe = + BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x); __shared__ float token_scale; if (tid == 0) { if (scale_ub) { diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh deleted file mode 100644 index 08063356012b8..0000000000000 --- a/csrc/reduction_utils.cuh +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Adapted from - * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh - * Copyright (c) 2023, The vLLM team. - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#include "cuda_compat.h" - -namespace vllm { - -namespace detail { - -template -__inline__ __device__ T _max(T a, T b) { - return max(a, b); -} - -template -__inline__ __device__ T _sum(T a, T b) { - return a + b; -} - -} // namespace detail - -template -using ReduceFnType = T (*)(T, T); - -// Helper function to return the next largest power of 2 -static constexpr int _nextPow2(unsigned int num) { - if (num <= 1) return num; - return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); -} - -template -__inline__ __device__ T warpReduce(T val, ReduceFnType fn) { - static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0, - "numLanes is not a positive power of 2!"); - static_assert(numLanes <= WARP_SIZE); -#pragma unroll - for (int mask = numLanes >> 1; mask > 0; mask >>= 1) - val = fn(val, VLLM_SHFL_XOR_SYNC(val, mask)); - - return val; -} - -template -__inline__ __device__ T blockReduce(T val, ReduceFnType fn) { - static_assert(maxBlockSize <= 1024); - if constexpr (maxBlockSize > WARP_SIZE) { - val = warpReduce(val, fn); - // Calculates max number of lanes that need to participate in the last - // warpReduce - constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE; - static __shared__ T shared[maxActiveLanes]; - int lane = threadIdx.x % WARP_SIZE; - int wid = threadIdx.x / WARP_SIZE; - if (lane == 0) shared[wid] = val; - - __syncthreads(); - - val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] - : (T)(0.0f); - val = warpReduce(val, fn); - } else { - // A single warpReduce is equal to blockReduce - val = warpReduce(val, fn); - } - return val; -} - -template -__inline__ __device__ T blockReduceMax(T val) { - return blockReduce(val, detail::_max); -} - -template -__inline__ __device__ T blockReduceSum(T val) { - return blockReduce(val, detail::_sum); -} - -} // namespace vllm diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index fcc444842213a..9c6364ecc6792 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -83,7 +83,7 @@ def test_models( for m in E4M3_KV_MODELS]) # Due to low-precision numerical divergence, we only test logprob of 4 tokens @pytest.mark.parametrize("max_tokens", [4]) -@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) +@pytest.mark.parametrize("chunked_prefill_token_size", [4, 16]) @pytest.mark.parametrize("enforce_eager", [False, True]) # NOTE: Increasing this in this suite will fail CI because we currently cannot # reset distributed env properly. Use a value > 1 just when you test. From df1a21131d951ba8ee65363aeb9b9486f569aa4f Mon Sep 17 00:00:00 2001 From: zifeitong Date: Wed, 21 Aug 2024 18:36:24 -0700 Subject: [PATCH 13/16] [Model] Fix Phi-3.5-vision-instruct 'num_crops' issue (#7710) --- docs/source/models/supported_models.rst | 4 ++-- tests/models/test_phi3v.py | 2 +- vllm/config.py | 6 +++++- vllm/inputs/registry.py | 11 +++++++++-- vllm/model_executor/models/phi3v.py | 12 ++++++------ vllm/transformers_utils/config.py | 15 ++++++++++++++- 6 files changed, 37 insertions(+), 13 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 1692e13c4ec06..7a9c87f406c66 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -225,9 +225,9 @@ Multimodal Language Models - :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc. 
- * - :code:`Phi3VForCausalLM` - - Phi-3-Vision + - Phi-3-Vision, Phi-3.5-Vision - Image - - :code:`microsoft/Phi-3-vision-128k-instruct`, etc. + - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc. - * - :code:`MiniCPMV` - MiniCPM-V diff --git a/tests/models/test_phi3v.py b/tests/models/test_phi3v.py index ccfc98a325982..197e63b1b1e52 100644 --- a/tests/models/test_phi3v.py +++ b/tests/models/test_phi3v.py @@ -21,7 +21,7 @@ "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n", }) -models = ["microsoft/Phi-3-vision-128k-instruct"] +models = ["microsoft/Phi-3.5-vision-instruct"] def vllm_to_hf_output(vllm_output: Tuple[List[int], str, diff --git a/vllm/config.py b/vllm/config.py index 7e62a727115ef..4cbdde5e113a2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -13,7 +13,9 @@ from vllm.model_executor.models import ModelRegistry from vllm.platforms import current_platform from vllm.tracing import is_otel_available, otel_import_error_traceback -from vllm.transformers_utils.config import get_config, get_hf_text_config +from vllm.transformers_utils.config import (get_config, + get_hf_image_processor_config, + get_hf_text_config) from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes, cuda_device_count_stateless, get_cpu_memory, is_cpu, is_hip, is_neuron, is_openvino, is_xpu, @@ -167,6 +169,8 @@ def __init__( self.hf_config = get_config(self.model, trust_remote_code, revision, code_revision, rope_scaling, rope_theta) self.hf_text_config = get_hf_text_config(self.hf_config) + self.hf_image_processor_config = get_hf_image_processor_config( + self.model, revision) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) # Choose a default enforce_eager value if the user did not specify diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index deb66f0b0cb35..ae6c6c05d9f72 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -2,8 +2,8 @@ from array import array from collections import UserDict from dataclasses import dataclass -from typing import (TYPE_CHECKING, Callable, Dict, Mapping, Optional, Protocol, - Tuple, Type) +from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional, + Protocol, Tuple, Type) from torch import nn from transformers import PretrainedConfig @@ -55,6 +55,13 @@ def get_hf_config(self, hf_config_type: Type[C] = PretrainedConfig) -> C: return hf_config + def get_hf_image_processor_config(self) -> Dict[str, Any]: + """ + Get the HuggingFace image processor configuration of the model. + """ + + return self.model_config.hf_image_processor_config + N = TypeVar("N", bound=Type[nn.Module]) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 9ccd6ef6d9ace..4854377215608 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -15,8 +15,8 @@ # limitations under the License. 
import re from functools import lru_cache -from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, - TypedDict, Union) +from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, + Tuple, TypedDict, Union) import numpy as np import torch @@ -324,12 +324,12 @@ def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16): # Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181 def get_phi3v_image_feature_size( - hf_config: PretrainedConfig, + hf_config: Dict[str, Any], *, input_height: int, input_width: int, ) -> int: - num_crops = getattr(hf_config, "num_crops", 16) + num_crops = hf_config.get("num_crops", 16) new_width, new_height = _calc_hd_transform_size(width=input_width, height=input_height, hd_num=num_crops) @@ -341,7 +341,7 @@ def get_phi3v_image_feature_size( def get_max_phi3v_image_tokens(ctx: InputContext): return get_phi3v_image_feature_size( - ctx.get_hf_config(), + ctx.get_hf_image_processor_config(), input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, ) @@ -395,7 +395,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): return llm_inputs model_config = ctx.model_config - hf_config = ctx.get_hf_config() + hf_config = ctx.get_hf_image_processor_config() image_data = multi_modal_data["image"] if isinstance(image_data, Image.Image): diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index d3024965c0b4c..0f86b02deb21a 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -1,8 +1,10 @@ import contextlib from pathlib import Path -from typing import Dict, Optional, Type, Union +from typing import Any, Dict, Optional, Type, Union from transformers import GenerationConfig, PretrainedConfig +from transformers.models.auto.image_processing_auto import ( + get_image_processor_config) from transformers.models.auto.modeling_auto import ( MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) @@ -98,6 +100,17 @@ def get_config( return config +def get_hf_image_processor_config( + model: Union[str, Path], + revision: Optional[str] = None, + **kwargs, +) -> Dict[str, Any]: + # Separate model folder from file path for GGUF models + if Path(model).is_file() and Path(model).suffix == ".gguf": + model = Path(model).parent + return get_image_processor_config(model, revision=revision, **kwargs) + + def get_hf_text_config(config: PretrainedConfig): """Get the "sub" config relevant to llm for multi modal models. No op for pure text models. 
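A rough usage sketch of the new helper (the model name is the one used in the
updated Phi-3.5-vision test, and fetching the config requires access to the
HF Hub):

from vllm.transformers_utils.config import get_hf_image_processor_config

cfg = get_hf_image_processor_config("microsoft/Phi-3.5-vision-instruct")
num_crops = cfg.get("num_crops", 16)  # same default as used in phi3v.py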
From cde9183b40a88f4210a7e965a430ae860aba5f6d Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Wed, 21 Aug 2024 20:18:11 -0600 Subject: [PATCH 14/16] [Bug][Frontend] Improve ZMQ client robustness (#7443) Signed-off-by: Joe Runde --- tests/entrypoints/openai/rpc/__init__.py | 0 .../entrypoints/openai/rpc/test_zmq_client.py | 119 ++++++++++++++++++ vllm/entrypoints/openai/api_server.py | 5 +- vllm/entrypoints/openai/rpc/__init__.py | 4 - vllm/entrypoints/openai/rpc/client.py | 70 +++++++---- vllm/envs.py | 6 + 6 files changed, 176 insertions(+), 28 deletions(-) create mode 100644 tests/entrypoints/openai/rpc/__init__.py create mode 100644 tests/entrypoints/openai/rpc/test_zmq_client.py diff --git a/tests/entrypoints/openai/rpc/__init__.py b/tests/entrypoints/openai/rpc/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/entrypoints/openai/rpc/test_zmq_client.py b/tests/entrypoints/openai/rpc/test_zmq_client.py new file mode 100644 index 0000000000000..631d15cd03ed7 --- /dev/null +++ b/tests/entrypoints/openai/rpc/test_zmq_client.py @@ -0,0 +1,119 @@ +import asyncio +import tempfile +import unittest +import unittest.mock +import uuid + +import pytest +import pytest_asyncio + +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.rpc.client import (AsyncEngineRPCClient, + RPCClientClosedError) +from vllm.entrypoints.openai.rpc.server import AsyncEngineRPCServer + + +@pytest.fixture(scope="function") +def tmp_socket(): + with tempfile.TemporaryDirectory() as td: + yield f"ipc://{td}/{uuid.uuid4()}" + + +@pytest_asyncio.fixture(scope="function") +async def dummy_server(tmp_socket, monkeypatch): + dummy_engine = unittest.mock.AsyncMock() + + def dummy_engine_builder(*args, **kwargs): + return dummy_engine + + with monkeypatch.context() as m: + m.setattr(AsyncLLMEngine, "from_engine_args", dummy_engine_builder) + server = AsyncEngineRPCServer(None, None, rpc_path=tmp_socket) + + loop = asyncio.get_running_loop() + server_task = loop.create_task(server.run_server_loop()) + + try: + yield server + finally: + server_task.cancel() + server.cleanup() + + +@pytest_asyncio.fixture(scope="function") +async def client(tmp_socket): + client = AsyncEngineRPCClient(rpc_path=tmp_socket) + # Sanity check: the server is connected + await client._wait_for_server_rpc() + + try: + yield client + finally: + client.close() + + +@pytest.mark.asyncio +async def test_client_data_methods_use_timeouts(monkeypatch, dummy_server, + client: AsyncEngineRPCClient): + with monkeypatch.context() as m: + # Make the server _not_ reply with a model config + m.setattr(dummy_server, "get_config", lambda x: None) + m.setattr(client, "_data_timeout", 10) + + # And ensure the task completes anyway + # (client.setup() invokes server.get_config()) + client_task = asyncio.get_running_loop().create_task(client.setup()) + with pytest.raises(TimeoutError, match="Server didn't reply within"): + await asyncio.wait_for(client_task, timeout=0.05) + + +@pytest.mark.asyncio +async def test_client_aborts_use_timeouts(monkeypatch, dummy_server, + client: AsyncEngineRPCClient): + with monkeypatch.context() as m: + # Hang all abort requests + m.setattr(dummy_server, "abort", lambda x: None) + m.setattr(client, "_data_timeout", 10) + + # Ensure the client doesn't hang + client_task = asyncio.get_running_loop().create_task( + client.abort("test request id")) + with pytest.raises(TimeoutError, match="Server didn't reply within"): + await asyncio.wait_for(client_task, timeout=0.05) + + 
+@pytest.mark.asyncio +async def test_client_data_methods_reraise_exceptions( + monkeypatch, dummy_server, client: AsyncEngineRPCClient): + with monkeypatch.context() as m: + # Make the server raise some random exception + exception = RuntimeError("Client test exception") + + def raiser(): + raise exception + + m.setattr(dummy_server.engine, "get_model_config", raiser) + m.setattr(client, "_data_timeout", 10) + + client_task = asyncio.get_running_loop().create_task(client.setup()) + # And ensure the task completes, raising the exception + with pytest.raises(RuntimeError, match=str(exception)): + await asyncio.wait_for(client_task, timeout=0.05) + + +@pytest.mark.asyncio +async def test_client_errors_after_closing(monkeypatch, dummy_server, + client: AsyncEngineRPCClient): + + client.close() + + # Healthchecks and generate requests will fail with explicit errors + with pytest.raises(RPCClientClosedError): + await client.check_health() + with pytest.raises(RPCClientClosedError): + async for _ in client.generate(None, None, None): + pass + + # But no-ops like aborting will pass + await client.abort("test-request-id") + await client.do_log_stats() diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 8e8371ef1559a..603ac19d8c04b 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -6,7 +6,7 @@ import re import tempfile from argparse import Namespace -from contextlib import asynccontextmanager +from contextlib import asynccontextmanager, suppress from http import HTTPStatus from typing import AsyncIterator, Optional, Set @@ -83,7 +83,8 @@ async def lifespan(app: FastAPI): async def _force_log(): while True: await asyncio.sleep(10) - await async_engine_client.do_log_stats() + with suppress(Exception): + await async_engine_client.do_log_stats() if not engine_args.disable_log_stats: task = asyncio.create_task(_force_log()) diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py index 571dca5f61fa4..efc7e43afdcc9 100644 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ b/vllm/entrypoints/openai/rpc/__init__.py @@ -10,10 +10,6 @@ # Success string used for RPC instructions. VLLM_RPC_SUCCESS_STR = "SUCCESS" -# Timeouts. -VLLM_RPC_SERVER_START_TIMEOUT_MS = 1000 -VLLM_RPC_HEALTH_TIMEOUT_MS = 10000 - # Minimum value of ZMQ.SOCKET_LIMIT to run mp. VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000 diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index 1f26348c74d6d..55b92d41975ea 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -1,5 +1,5 @@ import asyncio -from contextlib import contextmanager +from contextlib import contextmanager, suppress from typing import Any, AsyncGenerator, Mapping, Optional from uuid import uuid4 @@ -11,13 +11,12 @@ ParallelConfig, SchedulerConfig) # yapf: disable from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, - VLLM_RPC_HEALTH_TIMEOUT_MS, - VLLM_RPC_SERVER_START_TIMEOUT_MS, VLLM_RPC_SOCKET_LIMIT_CUTOFF, VLLM_RPC_SUCCESS_STR, VLLM_RPC_ZMQ_HWM, RPCAbortRequest, RPCGenerateRequest, RPCUtilityRequest) # yapf: enable +from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS from vllm.inputs import PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -32,6 +31,17 @@ INPROC_PROXY_PATH = f"inproc://{uuid4()}" +class RPCClientClosedError(Exception): + """Exception class raised when the client is used post-close. 
+ + The client can be closed, which closes the ZMQ context. This normally + happens on server shutdown. In some cases, methods like abort and + do_log_stats will still be called and then try to open a socket, which + causes a ZMQError and creates a huge stack trace. + So, we throw this error such that we can suppress it. + """ + + class AsyncEngineRPCClient: """ RPCClient that connects to the RPCServer wrapping AsyncLLMEngine. @@ -85,6 +95,8 @@ class AsyncEngineRPCClient: def __init__(self, rpc_path: str): self.context = zmq.asyncio.Context() + self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS + self._errored = False # Maximum number of sockets that can be opened (typically 65536). # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get) @@ -143,7 +155,6 @@ async def setup(self): # Wait until server is ready. await self._wait_for_server_rpc() - self._errored = False # Get the configs. self.model_config = await self._get_model_config_rpc() @@ -170,6 +181,15 @@ def close(self): @contextmanager def to_proxy_socket(self): # Connect to the RPCServer via the proxy. + + # Raise a sensible error if the client was already closed. + # This can happen if a server shutdown is triggered but some coroutines + # are still running requests. + # There should not be a race condition with this check because we don't + # yield to the event loop between here and opening the socket. + if self.context.closed: + raise RPCClientClosedError("The ZMQ client has already shut down") + # Note that we use DEALER to enable asynchronous communication # to enable streaming. socket = self.context.socket(zmq.constants.DEALER) @@ -189,9 +209,18 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, # Ping RPCServer with a request. await socket.send_multipart([cloudpickle.dumps(request)]) + # Make sure the server responds + if await socket.poll(timeout=self._data_timeout) == 0: + raise TimeoutError("Server didn't reply within " + f"{self._data_timeout} ms") + # Await the data from the Server. data = cloudpickle.loads(await socket.recv()) + if isinstance(data, Exception): + # Re-raise exceptions returned by the server + raise data + if not isinstance(data, expected_type): # LoRAConfig can be None. if expected_type == LoRAConfig and data is None: @@ -208,29 +237,28 @@ async def _send_one_way_rpc_request( self, request: RPC_REQUEST_TYPE, error_message: str, - timeout: Optional[int] = None, socket: Optional[zmq.asyncio.Socket] = None): """Send one-way RPC request to trigger an action.""" async def do_rpc_call(socket: zmq.asyncio.Socket, - request: RPC_REQUEST_TYPE, - timeout=None): + request: RPC_REQUEST_TYPE): await socket.send_multipart([cloudpickle.dumps(request)]) - if timeout is not None and await socket.poll(timeout=timeout) == 0: - raise TimeoutError(f"Server didn't reply within {timeout} ms") + if await socket.poll(timeout=self._data_timeout) == 0: + raise TimeoutError("Server didn't reply within " + f"{self._data_timeout} ms") return cloudpickle.loads(await socket.recv()) # Make a new socket connection. if socket is None: with self.to_proxy_socket() as socket: - response = await do_rpc_call(socket, request, timeout) + response = await do_rpc_call(socket, request) # Use existing socket connection. 
else: - response = await do_rpc_call(socket, request, timeout) + response = await do_rpc_call(socket, request) if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: if isinstance(response, Exception): @@ -255,8 +283,7 @@ async def _wait_for_server_rpc(self): await self._send_one_way_rpc_request( request=RPCUtilityRequest.IS_SERVER_READY, - error_message="Unable to start RPC Server", - timeout=VLLM_RPC_SERVER_START_TIMEOUT_MS) + error_message="Unable to start RPC Server") async def _get_model_config_rpc(self) -> ModelConfig: """Get the ModelConfig object from the RPC Server""" @@ -308,17 +335,17 @@ async def _is_tracing_enabled_rpc(self) -> bool: async def abort(self, request_id: str): """Send an ABORT_REQUEST signal to the RPC Server""" - - await self._send_one_way_rpc_request( - request=RPCAbortRequest(request_id), - error_message=f"RPCAbortRequest {request_id} failed") + with suppress(RPCClientClosedError): + await self._send_one_way_rpc_request( + request=RPCAbortRequest(request_id), + error_message=f"RPCAbortRequest {request_id} failed") async def do_log_stats(self): """Send a DO_LOG_STATS signal to the RPC Server""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.DO_LOG_STATS, - error_message="RPCRequest DO_LOG_STATS failed.") + with suppress(RPCClientClosedError): + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.DO_LOG_STATS, + error_message="RPCRequest DO_LOG_STATS failed.") @property def is_running(self) -> bool: @@ -393,7 +420,6 @@ async def check_health(self, await self._send_one_way_rpc_request( request=RPCUtilityRequest.IS_SERVER_HEALTHY, error_message="Got Unhealthy response from RPC Server", - timeout=VLLM_RPC_HEALTH_TIMEOUT_MS, socket=socket) async def encode(self, *args, diff --git a/vllm/envs.py b/vllm/envs.py index 4f7a7ad7821d5..24e09ee0e055f 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -56,6 +56,7 @@ VERBOSE: bool = False VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False VLLM_TEST_FORCE_FP8_MARLIN: bool = False + VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000 VLLM_ALLOW_ENGINE_USE_RAY: bool = False VLLM_PLUGINS: Optional[List[str]] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None @@ -374,6 +375,11 @@ def get_default_config_root(): (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in ("1", "true")), + # Time in ms for the zmq client to wait for a response from the backend + # server for simple data operations + "VLLM_RPC_GET_DATA_TIMEOUT_MS": + lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")), + # If set, allow running the engine as a separate ray actor, # which is a deprecated feature soon to be removed. 
# See https://github.com/vllm-project/vllm/issues/7045 From aae74ef95c370df92584e08939a15707ea8a5d3f Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 21 Aug 2024 23:42:14 -0400 Subject: [PATCH 15/16] Revert "[Kernel] Expand MoE weight loading + Add Fused Marlin MoE Kernel (#7527)" (#7764) --- CMakeLists.txt | 3 +- csrc/moe/marlin_moe_ops.cu | 1740 ----------------- csrc/moe/marlin_moe_ops.h | 12 - csrc/moe/torch_bindings.cpp | 9 - tests/weight_loading/models.txt | 2 - vllm/_custom_ops.py | 14 - .../layers/fused_moe/__init__.py | 14 +- .../layers/fused_moe/fused_moe.py | 134 +- vllm/model_executor/layers/fused_moe/layer.py | 206 +- .../compressed_tensors/compressed_tensors.py | 5 - .../compressed_tensors_moe.py | 283 --- .../model_executor/layers/quantization/fp8.py | 29 +- vllm/model_executor/model_loader/utils.py | 4 +- vllm/model_executor/models/jamba.py | 2 +- vllm/model_executor/models/mixtral.py | 1 - 15 files changed, 84 insertions(+), 2374 deletions(-) delete mode 100644 csrc/moe/marlin_moe_ops.cu delete mode 100644 csrc/moe/marlin_moe_ops.h delete mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 18e5109919104..217dc70c4b24e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -286,8 +286,7 @@ define_gpu_extension_target( set(VLLM_MOE_EXT_SRC "csrc/moe/torch_bindings.cpp" - "csrc/moe/topk_softmax_kernels.cu" - "csrc/moe/marlin_moe_ops.cu") + "csrc/moe/topk_softmax_kernels.cu") define_gpu_extension_target( _moe_C diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu deleted file mode 100644 index 1e170e80d2f70..0000000000000 --- a/csrc/moe/marlin_moe_ops.cu +++ /dev/null @@ -1,1740 +0,0 @@ -/* - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include -#include -#include -#include - -#include - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace marlin_moe { - -constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - -// Instances of `Vec` are used to organize groups of >>registers<<, as needed -// for instance as inputs to tensor core operations. Consequently, all -// corresponding index accesses must be compile-time constants, which is why we -// extensively use `#pragma unroll` throughout the kernel code to guarantee -// this. 
-template -struct Vec { - T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } -}; - -using I4 = Vec; - -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type -using FragA = Vec; -using FragB = Vec; -using FragC = Vec; -using FragS = Vec; // quantization scales - -// Predicated asynchronous global->shared copy; used for inputs A where we apply -// predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, - bool pred = true) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); -} - -// Asynchronous global->shared copy -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// Async copy fence. -__device__ inline void cp_async_fence() { - asm volatile("cp.async.commit_group;\n" ::); -} - -// Wait until at most `n` async copy stages are still pending. -template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -} - -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, - FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. 
We mostly follow the strategy in the link below, with some small -// changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -// Given 2 floats multiply by 2 scales (halves) -__device__ inline void scale_float(float* c, FragS& s) { - __half* s_ptr = reinterpret_cast<__half*>(&s); - c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); -} - -// Same as above, but for act_order (each K is multiplied individually) -__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2, - FragS& frag_s_3, FragS& frag_s_4, int i) { - __half2 s_val_1_2; - s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i]; - s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i]; - - __half2 s_val_3_4; - s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i]; - s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i]; - - frag_b[0] = __hmul2(frag_b[0], s_val_1_2); - frag_b[1] = __hmul2(frag_b[1], s_val_3_4); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} - -// For a given "a" of size [M,K] performs a permutation of the K columns based -// on the given "perm" indices. 
-__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, - int const* __restrict__ perm_int_ptr, - int4* __restrict__ out_int4_ptr, int size_m, - int size_k, int block_rows) { - int start_row = block_rows * blockIdx.x; - int finish_row = start_row + block_rows; - if (finish_row > size_m) { - finish_row = size_m; - } - int cur_block_rows = finish_row - start_row; - - int row_stride = size_k * sizeof(half) / 16; - - auto permute_row = [&](int row) { - int iters = size_k / blockDim.x; - int rest = size_k % blockDim.x; - - int offset = row * row_stride; - - half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); - half* out_half = reinterpret_cast(out_int4_ptr + offset); - - int base_k = 0; - - for (int i = 0; i < iters; i++) { - int cur_k = base_k + threadIdx.x; - int src_pos = perm_int_ptr[cur_k]; - - out_half[cur_k] = a_row_half[src_pos]; - - base_k += blockDim.x; - } - - if (rest) { - if (threadIdx.x < rest) { - int cur_k = base_k + threadIdx.x; - int src_pos = perm_int_ptr[cur_k]; - - out_half[cur_k] = a_row_half[src_pos]; - } - } - }; - - for (int i = 0; i < cur_block_rows; i++) { - int cur_row = start_row + i; - if (cur_row < size_m) { - permute_row(cur_row); - } - } -} - -__global__ void compute_expert_offsets(int const* __restrict__ topk_ids, - int* __restrict__ expert_offsets, - int topk_length, int block_size) { - int expert_id = threadIdx.x; - int num_experts = blockDim.x; - - int occurrences = 0; - for (int i = 0; i < topk_length; ++i) { - occurrences += (topk_ids[i] == expert_id); - } - expert_offsets[expert_id + 1] = occurrences; - __syncthreads(); - - if (threadIdx.x == 0) { - int tot_offset = 0; - expert_offsets[0] = 0; - for (int i = 0; i < num_experts; ++i) { - tot_offset += ceildiv(expert_offsets[i + 1], block_size) * block_size; - expert_offsets[i + 1] = tot_offset; - } - } - __syncthreads(); -} - -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__device__ inline void MarlinMoESingle( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? 
- bool apply_weights, // apply weights to output - int current_m_block // current m block to start kernel computation from -) { - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - - if constexpr (!has_act_order && group_blocks != -1) { - if (group_blocks >= thread_k_blocks) { - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts - // in the middle of group. - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - } - } - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - sorted_ids += 16 * thread_m_blocks; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - // A sizes/strides - - // stride of the A matrix in global memory - int a_gl_stride = prob_k / 8; - // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; - // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; - // between subsequent accesses within a tile - int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); - // between shared memory writes - constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); - // within a shared memory tile - constexpr int a_sh_rd_delta_i = a_sh_stride * 16; - // overall size of a tile - constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); - // number of 
shared write iterations for a tile - constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); - - // B sizes/strides - int b_gl_stride = 16 * prob_n / 32; - constexpr int b_sh_stride = 32 * thread_n_blocks / 4; - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); - constexpr int b_sh_wr_delta = threads; - constexpr int b_sh_rd_delta = threads; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = !has_act_order && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks - : 1; - constexpr int s_sh_stage = s_tb_groups * s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - // Scale size/strides with act_order - constexpr int tb_k = 16 * thread_k_blocks; - constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; - // constexpr int act_s_row_stride = 1; - // int act_s_col_stride = act_s_row_stride * num_groups; - int act_s_col_stride = 1; - int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; - int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; - - constexpr int sorted_sh_stride = threads; - constexpr int sorted_gl_stride = threads; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = - b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x; - int b_sh_rd = threadIdx.x; - - // For act_order - constexpr int k_iter_size = tb_k / b_sh_wr_iters; - int slice_k_start = tb_k * slice_row; - int slice_k_finish = slice_k_start + tb_k * slice_iters; - int slice_k_start_shared_fetch = slice_k_start; - int slice_n_offset = act_s_col_tb_stride * slice_col; - - // No act_order - int s_gl_rd; - if constexpr (group_blocks == -1 || group_blocks == 0) { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - } - int s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - int s_sh_rd; - if constexpr (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; - - int sh_first_group_id = -1; - int sh_num_groups = -1; - constexpr int sh_max_num_groups = 32; - - int shs_size; - if constexpr (has_act_order) - shs_size = sh_max_num_groups * s_sh_stride + threads; - else - shs_size = group_blocks > 0 ? 
stages * s_sh_stage : threads; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_s = sh_g_idx + (stages * g_idx_stage); - int* sh_sorted = (int*)(sh_s + shs_size); - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - int a_idx = a_sh_wr_delta * i + a_sh_wr; - int row = a_idx / a_gl_rd_delta_o; - if (row >= prob_m) { - a_sh_wr_pred[i] = false; - } else { - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - } - } - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order - FragS act_frag_s[2][4][4]; // For act-order - - // Zero accumulators. 
- auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, - int last_group_id) { - sh_first_group_id = first_group_id; - sh_num_groups = last_group_id - first_group_id + 1; - - if (sh_num_groups < sh_max_num_groups) { - sh_num_groups = sh_max_num_groups; - } - - if (sh_first_group_id + sh_num_groups > num_groups) { - sh_num_groups = num_groups - sh_first_group_id; - } - - int row_offset = first_group_id * s_gl_stride; - - if (is_async) { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], - &scales_ptr[row_offset + (i * s_gl_stride) + - slice_n_offset + threadIdx.x]); - } - } - } else { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - sh_s[(i * s_sh_stride) + threadIdx.x] = - scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + - threadIdx.x]; - } - } - } - }; - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; - int row = a_idx / a_gl_stride; - int sorted_row = - replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; - int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; - if (sorted_row < tot_m * (replicate_input ? 1 : topk) && - new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) { - cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], - a_sh_wr_pred[i]); - } - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); - B_ptr[i] += b_gl_rd_delta_o; - } - - if constexpr (has_act_order) { - // Fetch g_idx thread-block portion - int full_pipe = a_off; - int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; - if (cur_k < prob_k && cur_k < slice_k_finish) { - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - - int4 const* cur_g_idx_stage_ptr = - reinterpret_cast(&g_idx[cur_k]); - - if (threadIdx.x < g_idx_stage) { - cp_async4_pred(&sh_g_idx_stage[threadIdx.x], - &cur_g_idx_stage_ptr[threadIdx.x]); - } - } - } else { - if constexpr (group_blocks != -1) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], - &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } - } - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. 
- cp_async_fence(); - }; - - // TODO we are currently hitting illegal memory accesses when fetching - // sorted_ids to shared data: fix this - auto fetch_sorted_ids_to_shared = [&]() { - const int mpt = ceildiv(prob_m, threads); - for (int i = 0; i < mpt; i++) { - if ((i * sorted_gl_stride) + threadIdx.x < prob_m) { - sh_sorted[(i * sorted_sh_stride) + threadIdx.x] = - sorted_ids[(i * sorted_gl_stride) + threadIdx.x]; - } - } - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - frag_b_quant[k % 2] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); - }; - - bool is_same_group[stages]; - int same_group_id[stages]; - - auto init_same_group = [&](int pipe) { - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - int group_id_1 = sh_g_idx_int_ptr[0]; - int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; - - is_same_group[pipe] = group_id_1 == group_id_2; - same_group_id[pipe] = group_id_1; - }; - - auto fetch_scales_to_registers = [&](int k, int full_pipe) { - int pipe = full_pipe % stages; - - if constexpr (!has_act_order) { - // No act-order case - if constexpr (group_blocks != -1) { - if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - int warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; - - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - reinterpret_cast(&frag_s[k % 2])[0] = - sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } - } - - return; - } - - // Act-order case - - // Determine K of the "current" thread-block - int cur_k = slice_k_start + tb_k * full_pipe; - if (cur_k >= prob_k || cur_k >= slice_k_finish) { - return; - } - - // Reset (to current thread-block) since we read g_idx portion from the - // shared memory - cur_k = 0; - - // Progress to current iteration - cur_k += k_iter_size * (k % b_sh_wr_iters); - - // Determine "position" inside the thread-block (based on warp and - // thread-id) - int warp_id = threadIdx.x / 32; - int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N - - int warp_row = warp_id / n_warps; - int warp_col = warp_id % n_warps; - - cur_k += warp_row * 16; - - int th_id = threadIdx.x % 32; - cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix - - int s_col_shift = - /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + - (th_id / 4) * act_s_col_stride; - - if (is_same_group[pipe]) { - 
if (k % 2 == 0) { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + - s_col_shift]; - } else { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); - } - - for (int i = 1; i < 4; i++) { - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); - } - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - constexpr int k_frag_offsets[4] = {0, 1, 8, - 9}; // Tensor core offsets per thread - - #pragma unroll - for (int i = 0; i < 4; i++) { - int actual_k = cur_k + k_frag_offsets[i]; - - int group_id = sh_g_idx_int_ptr[actual_k]; - int rel_group_id = group_id - sh_first_group_id; - - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - sh_s[rel_group_id * s_sh_stride + s_col_shift]; - } - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - int b_quant = frag_b_quant[k % 2][j]; - int b_quant_shift = b_quant >> 8; - - FragB frag_b0 = dequant(b_quant); - - // Apply scale to frag_b0 - if constexpr (has_act_order) { - scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); - } else { - if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k % 2][j], 0); - } - } - - FragB frag_b1 = dequant(b_quant_shift); - - // Apply scale to frag_b1 - if constexpr (has_act_order) { - scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); - - } else { - if constexpr (group_blocks != -1) { - scale(frag_b1, frag_s[k % 2][j], 1); - } - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride / 2; - if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride; - constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + - (threadIdx.x % b_sh_stride); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. 
- - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - int c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. 
- #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - int c_idx = - c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); - int sorted_row = sorted_ids[c_idx / c_gl_stride]; - int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; - cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], - sorted_row < tot_m * topk && - (8 * (i / 2) + row < prob_m && - (i < (thread_m_blocks - 1) * 4 || - sorted_ids[8 * (i / 2) + row] < tot_m * topk))); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (8 * (i / 2) + row < prob_m && - (i < (thread_m_blocks - 1) * 4 || - sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - __half2float(reinterpret_cast<__half*>(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half*>(&c)[j] = - __float2half(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - int c_idx = - c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); - int row = sorted_ids[c_idx / c_gl_stride]; - if (row < tot_m * topk) { - int new_idx = row * c_gl_stride + c_idx % c_gl_stride; - C[new_idx] = c; - } - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - - // For per-column quantization we finally apply the scale here - if constexpr (!has_act_order && group_blocks == -1) { - res = __hmul2(res, s[0]); - } - - ((half2*)sh)[idx] = res; - }; - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * 
thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - int row = sorted_ids[c_gl_wr / c_gl_stride]; - if (row < tot_m * topk) { - int off = row * c_gl_stride + c_gl_wr % c_gl_stride; - if (!apply_weights) { - C[off] = sh[c_sh_rd]; - } else { - __half* ctrg = reinterpret_cast<__half*>(&C[off]); - __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); - for (int j = 0; j < 8; ++j) { - ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); - } - } - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - // TODO re-enable after fixing this function - // fetch_sorted_ids_to_shared(); - __syncthreads(); - - #pragma unroll - for (int i = 0; i < stages - 1; i++) { - if (has_act_order && i == 0) { - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); - } - fetch_to_shared(i, i, i < slice_iters); - } - - zero_accums(); - wait_for_stage(); - init_same_group(0); - fetch_to_registers(0, 0); - fetch_scales_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - slice_k_start_shared_fetch += tb_k * (stages - 1); - }; - if (slice_iters) { - start_pipes(); - } - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines - // have even length meaning that the next iteration will always start at - // index 0. - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - fetch_scales_to_registers(k + 1, pipe); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) { - break; - } - } - - a_gl_rd += a_gl_rd_delta_o * stages; - slice_k_start += tb_k * stages; - slice_k_start_shared_fetch += tb_k * stages; - - if constexpr (has_act_order) { - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) { - fetch_scales_to_shared(false, first_group_id, last_group_id); - __syncthreads(); - } - } - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. 
- if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if constexpr (!has_act_order && group_blocks == -1) { - if (last) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } - } - - thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1) { - if (last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - } - } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - - // Update slice k/n for scales loading - if constexpr (has_act_order) { - slice_k_start = tb_k * slice_row; - slice_k_finish = slice_k_start + tb_k * slice_iters; - slice_k_start_shared_fetch = slice_k_start; - slice_n_offset = act_s_col_tb_stride * slice_col; - - } else { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } - start_pipes(); - } - } - } -} - -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void MarlinMoE( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? 
- bool apply_weights, // apply weights to output - int current_m_block, // current m block to start kernel computation from - int max_par // maximum parallelism -) { - int m_block_ctr = current_m_block; - - const int* sorted_ids_expert = - sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par; - int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx]; - if (tot_its == 0) { - return; - } - int tot_m_blocks = ceildiv(tot_its, 16); - int pad = 16 * tot_m_blocks - tot_its; - - if (m_block_ctr >= tot_m_blocks) { - return; - } - - int max_block = tot_m_blocks - m_block_ctr; - prob_m = tot_its - 16 * m_block_ctr; - - int par = 1; - if (max_block > 4) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * max_block - pad) / 64; - par = min((16 * max_block - pad) / 64, max_par); - prob_m = 64 * par; - m_block_ctr += 4 * (par - 1); - max_block = 4; - } - - if (max_block == 1) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else if (max_block == 2) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else if (max_block == 3) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } -} - -#else - -__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, - int const* __restrict__ perm_int_ptr, - int4* __restrict__ out_int4_ptr, int size_m, - int size_k, int block_rows) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -__global__ void compute_expert_offsets(int const* __restrict__ topk_ids, - int* __restrict__ expert_offsets, - int topk_length, int block_size) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void MarlinMoE( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction 
dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? - bool apply_weights, // apply weights to output - int current_m_block, // current m block to start kernel computation from - int max_par // maximum parallelism -) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -#endif - -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. -const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory -// const int SHARED_MEM = -// 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) - -static constexpr int min_thread_n = 64; -static constexpr int min_thread_k = 64; - -#define __CALL_IF_MOE(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ - else if (thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ - num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - MarlinMoE, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - MarlinMoE \ - <<>>( \ - A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ - num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par); \ - } - -typedef struct { - int thread_k; - int thread_n; - int num_threads; -} thread_config_t; - -thread_config_t small_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {128, 128, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 256, 256}, // Reduce K 2X, increase N 2X - {64, 128, 128}, // Reduce K 2X, same N -}; - -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {64, 256, 256}, // Default - {128, 128, 256}, // Reduce N 2X, increase K 2X - {64, 128, 128}, // Reduce N 2X, same K - {128, 64, 128}, // Reduce N 4X, increase K 2X -}; - -bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, - int prob_k) { - // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || - th_config.num_threads == -1) { - return false; - } - - // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { - return false; - } - - // thread_k can be only 128 or 64 (because it must be less than groupsize - // which is 128) - if (th_config.thread_k != 128 && th_config.thread_k != 64) { - return false; - } - - // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { - return false; - } - - // num_threads must be at least 128 (= 4 warps) - if (th_config.num_threads < 128) { - return false; - } - - return true; -} - -thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - - } else { - for (auto th_config : large_batch_thread_configs) { - if 
(is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - } - - return thread_config_t{-1, -1, -1}; -} - -#define CALL_IF_MOE(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) - -void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, - const void* sorted_ids, const void* topk_weights, - const void* topk_ids, const void* s, const void* g_idx, - const void* perm, void* a_tmp, void* expert_offsets, - int prob_m, int prob_n, int prob_k, void* workspace, - bool has_act_order, bool is_k_full, int num_groups, - int group_size, int num_experts, int topk, - int moe_block_size, int dev, cudaStream_t stream, - int thread_k, int thread_n, int sms, int max_par, - bool replicate_input, bool apply_weights) { - TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, - ", ", prob_n, ", ", prob_k, "]"); - - if (sms == -1) { - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - } - - // Set thread config - thread_config_t th_config; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; - } else { - // Auto config - th_config = determine_thread_config(prob_m, prob_n, prob_k); - } - - TORCH_CHECK(is_valid_config(th_config, prob_m, prob_n, prob_k), - "Invalid thread config: thread_k = " + str(th_config.thread_k) + - ", thread_n = " + str(th_config.thread_n) + - ", num_threads = " + str(th_config.num_threads) + - " for MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " + - str(prob_n) + "]"); - - int num_threads = th_config.num_threads; - thread_k = th_config.thread_k; - thread_n = th_config.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - - int blocks = sms; - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - - int group_blocks = 0; - if (has_act_order) { - if (is_k_full) { - TORCH_CHECK(group_size != -1); - group_blocks = group_size / 16; - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } else { - TORCH_CHECK(group_size == 0); 
- group_blocks = 0; - } - - } else { - if (group_size == -1) { - group_blocks = -1; - } else { - group_blocks = group_size / 16; - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } - } - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - int tot_m = prob_m; - - const int* topk_ids_ptr = (const int*)topk_ids; - int* expert_offsets_ptr = (int*)expert_offsets; - compute_expert_offsets<<<1, num_experts, 0, stream>>>( - topk_ids_ptr, expert_offsets_ptr, tot_m * topk, moe_block_size); - - bool do_permute_a = has_act_order; - - // If we have a full K, then we can run the non-act-order version of Marlin - // (since the weight rows are reordered by increasing group ids, and by - // having a full K, we have full original groups) - if (is_k_full) { - has_act_order = false; - } - - for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) { - const int4* A_ptr = (const int4*)A; - int4* a_tmp_ptr = (int4*)a_tmp; - const int4* B_ptr = (const int4*)B + (prob_n * prob_k / 32) * expert_idx; - int4* C_ptr = (int4*)C; - const float* topk_weights_ptr = (const float*)topk_weights; - const int* sorted_ids_ptr = (const int*)sorted_ids; - const int4* s_ptr = - (const int4*)s + - (((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) * - prob_n / 8) * - expert_idx; - const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx; - const int* perm_ptr = (const int*)perm + prob_k * expert_idx; - int* locks = (int*)workspace; - - if (do_permute_a) { - // Permute A columns - int topk_rows = replicate_input ? tot_m : tot_m * topk; - int block_rows = ceildiv(topk_rows, blocks); - permute_cols_kernel<<>>( - A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows); - A_ptr = a_tmp_ptr; - } - - int max_m_blocks = ceildiv(tot_m, 16); - for (int m_block = 0; m_block < max_m_blocks; m_block += 16) { - // Define kernel configurations - - // make it max possible value - int thread_m_blocks = 4; - - if (false) { - } - CALL_IF_MOE(16, 4, 256) - CALL_IF_MOE(8, 8, 256) - CALL_IF_MOE(8, 4, 128) - CALL_IF_MOE(4, 8, 128) - else { - TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + - str(prob_n) + ", " + str(prob_k) + "]" + - ", has_act_order = " + str(has_act_order) + - ", num_groups = " + str(num_groups) + - ", group_size = " + str(group_size) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - } - } -} - -} // namespace marlin_moe - -torch::Tensor marlin_gemm_moe( - const torch::Tensor& a, const torch::Tensor& b_q_weights, - const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, - const torch::Tensor& topk_ids, const torch::Tensor& b_scales, - const torch::Tensor& g_idx, const torch::Tensor& perm, - torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, - bool replicate_input, bool apply_weights) { - int max_par = 4; - - int dev = a.get_device(); - - auto options_dtype = - torch::TensorOptions().dtype(a.dtype()).device(a.device()); - auto options_int = - torch::TensorOptions().dtype(torch::kInt).device(a.device()); - torch::Tensor c = torch::zeros({size_m, topk, size_n}, options_dtype); - torch::Tensor a_tmp = - replicate_input ? 
torch::zeros({size_m, size_k}, options_dtype) - : torch::zeros({size_m, topk, size_k}, options_dtype); - torch::Tensor expert_offsets = torch::empty({num_experts + 1}, options_int); - - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - // Detect groupsize and act_order - int num_groups = -1; - int group_size = -1; - bool has_act_order = g_idx.size(1) != 0; - - int b_rank = b_scales.sizes().size(); - TORCH_CHECK(b_rank == 3, "b_scales rank = ", b_rank, " is not 3"); - TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2), - " is not size_n = ", size_n); - num_groups = b_scales.size(1); - - if (has_act_order) { - if (is_k_full) { - TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); - TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, - ", is not divisible by num_groups = ", num_groups); - group_size = size_k / num_groups; - } else { - group_size = 0; - } - - } else { - if (num_groups > 1) { - TORCH_CHECK( - size_k % num_groups == 0, "size_k = ", size_k, - ", is not divisible by b_scales.size(0) = ", b_scales.size(0)); - group_size = size_k / num_groups; - } else { - group_size = -1; - } - } - - marlin_moe::marlin_mm_moe_f16i4( - a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), - topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), - g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), - expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), - has_act_order, is_k_full, num_groups, group_size, num_experts, topk, - moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, - thread_n, sms, max_par, replicate_input, apply_weights); - return c; -} \ No newline at end of file diff --git a/csrc/moe/marlin_moe_ops.h b/csrc/moe/marlin_moe_ops.h deleted file mode 100644 index 01ba8ff69850d..0000000000000 --- a/csrc/moe/marlin_moe_ops.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include - -torch::Tensor marlin_gemm_moe( - const torch::Tensor& a, const torch::Tensor& b_q_weights, - const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, - const torch::Tensor& topk_ids, const torch::Tensor& b_scales, - const torch::Tensor& g_idx, const torch::Tensor& perm, - torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, - bool replicate_input, bool apply_weights); \ No newline at end of file diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index cda1405b4e4f1..86e42af44df15 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -1,6 +1,5 @@ #include "core/registration.h" #include "moe_ops.h" -#include "marlin_moe_ops.h" TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // Apply topk softmax to the gating outputs. @@ -8,14 +7,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " "token_expert_indices, Tensor gating_output) -> ()"); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); - m.def( - "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " - "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " - "g_idx, Tensor! perm, Tensor! 
workspace, int size_m, int size_n, int " - "size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, " - "bool replicate_input, bool apply_weights) -> Tensor"); - - m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt index c074b4b44c768..064dbb1feee83 100644 --- a/tests/weight_loading/models.txt +++ b/tests/weight_loading/models.txt @@ -13,7 +13,5 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main -compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main -compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main awq, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ae90af563c0cf..b89a90ef0f70c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -300,20 +300,6 @@ def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int, return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) -def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, - size_k: int, size_n: int, - num_bits: int) -> torch.Tensor: - num_experts = b_q_weight.shape[0] - assert size_k % 16 == 0 - output = torch.empty((num_experts, size_k // 16, size_n * 2), - device=b_q_weight.device, - dtype=b_q_weight.dtype) - for e in range(num_experts): - output[e] = torch.ops._C.gptq_marlin_repack(b_q_weight[e], perm[e], - size_k, size_n, num_bits) - return output - - def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index fd6f41b90042e..3e0767c7d2665 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,17 +1,19 @@ -from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, + FusedMoEMethodBase) from vllm.triton_utils import HAS_TRITON -__all__ = ["FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported"] +__all__ = [ + "FusedMoE", + "FusedMoEMethodBase", +] if HAS_TRITON: from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_experts, fused_marlin_moe, fused_moe, fused_topk, - get_config_file_name, grouped_topk) + fused_experts, fused_moe, fused_topk, get_config_file_name, + grouped_topk) __all__ += [ - "fused_marlin_moe", "fused_moe", "fused_topk", "fused_experts", diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index d2b152320e11e..bcf25d2631042 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -323,16 +323,21 @@ def get_moe_configs(E: int, N: int, return None -def get_default_config(M: int, E: int, N: int, K: int, topk: int, - dtype: Optional[str], - is_marlin: bool) -> Dict[str, int]: +def get_default_config( + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: Optional[str], +) -> Dict[str, int]: config = { 'BLOCK_SIZE_M': 64, 
'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8 } - if M <= E or (is_marlin and M <= 32): + if M <= E: config = { 'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, @@ -342,14 +347,14 @@ def get_default_config(M: int, E: int, N: int, K: int, topk: int, return config -def try_get_optimal_moe_config(w1_shape: Tuple[int, ...], - w2_shape: Tuple[int, ...], - top_k: int, - dtype: Optional[str], - M: int, - override_config: Optional[Dict[str, - Any]] = None, - is_marlin: bool = False): +def try_get_optimal_moe_config( + w1_shape: Tuple[int, ...], + w2_shape: Tuple[int, ...], + top_k: int, + dtype: Optional[str], + M: int, + override_config: Optional[Dict[str, Any]] = None, +): if override_config: config = override_config else: @@ -363,8 +368,7 @@ def try_get_optimal_moe_config(w1_shape: Tuple[int, ...], config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Else use the default config - config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, - is_marlin) + config = get_default_config(M, E, N, w1_shape[2], top_k, dtype) return config @@ -437,108 +441,6 @@ def grouped_topk(hidden_states: torch.Tensor, return topk_weights, topk_ids -def fused_marlin_moe(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - gating_output: torch.Tensor, - g_idx1: torch.Tensor, - g_idx2: torch.Tensor, - rand_perm1: torch.Tensor, - rand_perm2: torch.Tensor, - topk: int, - renormalize: bool, - override_config: Optional[Dict[str, Any]] = None, - use_fp8: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor: - """ - This function computes a Mixture of Experts (MoE) layer using two sets of - weights, w1 and w2, and top-k gating mechanism. - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w1 (torch.Tensor): The first set of expert weights. - - w2 (torch.Tensor): The second set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2. - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. - """ - # Check constraints. 
- assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") - assert hidden_states.shape[ - 1] == w1.shape[1] * 16, "Hidden size mismatch w1" - assert hidden_states.shape[ - 1] == w2.shape[2] // 2, "Hidden size mismatch w2" - assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.is_contiguous(), "Expert weights1 must be contiguous" - assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] - - #TODO fp8 is not implemented yet - assert not use_fp8 - - M, K = hidden_states.shape - E = w1.shape[0] - N = w2.shape[1] * 16 - - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) - - get_config_func = functools.partial(try_get_optimal_moe_config, - w1.shape, - w2.shape, - topk_ids.shape[1], - "float8" if use_fp8 else None, - override_config=override_config, - is_marlin=True) - config = get_config_func(M) - - block_size_m = config['BLOCK_SIZE_M'] - - sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) - - max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device="cuda", - requires_grad=False) - - intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N), - device=hidden_states.device, - dtype=hidden_states.dtype) - - intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( - hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale, - g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk, - block_size_m, True, False) - - ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) - - intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( - intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids, - w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk, - block_size_m, False, True) - - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1) - - def get_config_dtype_str(dtype: torch.dtype, use_int8_w8a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False): diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 3a77bf30131f9..4e29ab701b937 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,5 +1,4 @@ from abc import abstractmethod -from enum import Enum from typing import List, Optional, Tuple import torch @@ -16,12 +15,6 @@ logger = init_logger(__name__) -class FusedMoeWeightScaleSupported(Enum): - TENSOR = "tensor" - CHANNEL = "channel" - GROUP = "group" - - class FusedMoEMethodBase(QuantizeMethodBase): @abstractmethod @@ -206,182 +199,55 @@ def __init__( params_dtype=params_dtype, weight_loader=self.weight_loader) - def _load_per_tensor_weight_scale(self, shard_id: str, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - expert_id: int): - param_data = param.data - # for per tensor weight quantization - if shard_id in ("w1", "w3"): - # We have to keep the weight scales of w1 and w3 because - # we need to re-quantize w1/w3 weights after weight loading. 
- idx = 0 if shard_id == "w1" else 1 - param_data[expert_id][idx] = loaded_weight - # If we are in the row parallel case (down_proj) - elif shard_id == "w2": - param_data[expert_id] = loaded_weight - - def _load_model_weight_or_group_weight_scale(self, shard_dim: int, - expert_data: torch.Tensor, - shard_id: str, - loaded_weight: torch.tensor, - tp_rank: int): - # Load grouped weight scales for group quantization - # or model weights - if shard_id == "w2": - self._load_w2(shard_id=shard_id, - shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=tp_rank) - elif shard_id in ("w1", "w3"): - self._load_w13(shard_id=shard_id, - shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=tp_rank) - - def _load_per_channel_weight_scale(self, expert_data: torch.Tensor, - shard_dim: int, shard_id: str, - loaded_weight: torch.tensor, - tp_rank: int): - # for per channel weight quantization - if shard_id == "w2": - expert_data.copy_(loaded_weight) - elif shard_id in ("w1", "w3"): - self._load_w13(shard_id=shard_id, - shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=tp_rank) - - def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, - shard_id: str, loaded_weight: torch.tensor, tp_rank: int): - - # Index the loaded weight for tp sharding. - # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim - shard_size = expert_data.shape[shard_dim] // 2 - loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, - shard_size) - # Narrow parameter and load. - # w1, gate_proj: Load into first logical weight of w13. - if shard_id == "w1": - expert_data = expert_data.narrow(shard_dim, 0, shard_size) - # w3, up_proj: Load into second logical weight of w13. - else: - assert shard_id == "w3" - expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) - expert_data.copy_(loaded_weight) - - def _load_w2(self, expert_data: torch.Tensor, shard_dim: int, - shard_id: str, loaded_weight: torch.tensor, tp_rank: int): - - # Index the loaded weight for tp sharding. - # down_proj: "RowParallel" so tp sharding on input_dim - # Narrow parameter and load. - shard_size = expert_data.shape[shard_dim] - loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, - shard_size) - # w2, down_proj: Load into only logical weight of w2. - expert_data.copy_(loaded_weight) - - def _load_single_value(self, param: torch.nn.Parameter, - loaded_weight: torch.Tensor, expert_id: int): - param_data = param.data - - # Input scales can be loaded directly and should be equal. - param_data[expert_id] = loaded_weight - def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, shard_id: str, expert_id: int) -> None: - if shard_id not in ("w1", "w2", "w3"): raise ValueError(f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}.") - WEIGHT_SCALE_SUPPORTED = [ - e.value for e in FusedMoeWeightScaleSupported - ] - # Fetch the dim to shard the parameter/loaded weight - # based on the shard id. This will be whatever - # dimension intermediate_size is used. - SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0} + # Special case for fp8 scales. 
+ if getattr(param, "is_fp8_scale", False): + self._load_fp8_scale(param.data, loaded_weight, weight_name, + shard_id, expert_id) + return expert_data = param.data[expert_id] tp_rank = get_tensor_model_parallel_rank() - # is_transposed: whether or not the parameter is transposed on disk - # If transposed, the loaded weight will be transposed and the dim - # to shard the loaded weight will be flipped. - is_transposed = getattr(param, "is_transposed", False) - shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] - if is_transposed: - loaded_weight = loaded_weight.t().contiguous() - shard_dim = ~shard_dim - - # Case weight_scales - if "weight_scale" in weight_name: - # load the weight scaling based on the quantization scheme - # supported weight scales can be found in - # FusedMoeWeightScaleSupported - # TODO @dsikka: once hardened, refactor to use vLLM Parameters - # specific to each case - quant_method = getattr(param, "quant_method", None) - if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value: - self._load_per_channel_weight_scale( - shard_id=shard_id, - shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=tp_rank) - elif quant_method == FusedMoeWeightScaleSupported.GROUP.value: - self._load_model_weight_or_group_weight_scale( - shard_id=shard_id, - shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=tp_rank) - elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value: - self._load_per_tensor_weight_scale(shard_id=shard_id, - param=param, - loaded_weight=loaded_weight, - expert_id=expert_id) - else: - raise ValueError( - f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}") - return - - if "weight_shape" in weight_name: - self._load_single_value(param=param, - loaded_weight=loaded_weight, - expert_id=expert_id) - return + # If transposed, weight is saved as [input_dim, output_dim] + # Otherwise, weight is saved as [output_dim, input_dim] + # Default is not transposed/input dim is dim 1 + input_dim = getattr(param, "input_dim", 1) + output_dim = getattr(param, "output_dim", 0) - # Case input scale - if "input_scale" in weight_name: - # Note: input_scale loading is only supported for fp8 - if param.data[expert_id] != 1 and (param.data[expert_id] - - loaded_weight).abs() > 1e-5: - raise ValueError( - "input_scales of w1 and w3 of a layer " - f"must be equal. But got {param.data[expert_id]} " - f"vs. {loaded_weight}") - - self._load_single_value(param=param, - loaded_weight=loaded_weight, - expert_id=expert_id) - return + # Index the loaded weight for tp sharding. + # down_proj: "RowParallel" so tp sharding on input_dim + if shard_id == "w2": + shard_dim = input_dim + shard_size = expert_data.shape[shard_dim] + # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim + elif shard_id in ("w1", "w3"): + shard_dim = output_dim + shard_size = expert_data.shape[output_dim] // 2 + offset = shard_size * tp_rank + loaded_weight = loaded_weight.narrow(shard_dim, offset, shard_size) - # Case model weights - if "weight" in weight_name: - self._load_model_weight_or_group_weight_scale( - shard_id=shard_id, - shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=tp_rank) - return + # Narrow parameter and load. + # w1, gate_proj: Load into first logical weight of w13. + if shard_id == "w1": + expert_data = expert_data.narrow(shard_dim, 0, shard_size) + expert_data.copy_(loaded_weight) + # w3, up_proj: Load into second logical weight of w13. 
+ elif shard_id == "w3": + expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) + expert_data.copy_(loaded_weight) + # w2, down_proj: Load into only logical weight of w2. + elif shard_id == "w2": + expert_data.copy_(loaded_weight) + else: + raise ValueError( + f"Expected shard_id w1,w2 or w3 but got {shard_id}") @staticmethod def select_experts(hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 759dd9c0dd4ef..ae75781927381 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -3,12 +3,9 @@ import torch from pydantic import BaseModel -from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501 - CompressedTensorsMoEMethod) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensorsScheme, CompressedTensorsUnquantized, @@ -67,8 +64,6 @@ def get_quant_method( return CompressedTensorsLinearMethod(self) if isinstance(layer, Attention): return CompressedTensorsKVCacheMethod(self) - if isinstance(layer, FusedMoE): - return CompressedTensorsMoEMethod(self) return None @classmethod diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py deleted file mode 100644 index 0e0ab9ce9169f..0000000000000 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ /dev/null @@ -1,283 +0,0 @@ -import enum -from enum import Enum -from typing import List, Optional - -import torch - -from vllm import _custom_ops as ops -from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase -from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - WNA16_SUPPORTED_BITS) -from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( - CompressionFormat) -from vllm.model_executor.utils import set_weight_attrs - - -class GPTQMarlinState(Enum): - REPACK = enum.auto() - READY = enum.auto() - - -__all__ = ["CompressedTensorsMoEMethod"] - - -class CompressedTensorsMoEMethod(FusedMoEMethodBase): - - def __init__( - self, - quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 - ): - self.quant_config = quant_config - # TODO: @dsikka: refactor this to use schemes as other kernels - # are supported + check if the layer is being ignored. 
- config = self.quant_config.target_scheme_map["Linear"].get("weights") - self.num_bits = config.num_bits - self.packed_factor = 32 // config.num_bits - self.strategy = config.strategy.value - self.group_size = config.group_size - assert config.symmetric, ( - "Only symmetric quantization is supported for MoE") - - if not (self.quant_config.quant_format - == CompressionFormat.pack_quantized.value - and self.num_bits in WNA16_SUPPORTED_BITS): - raise ValueError("For Fused MoE layers, only ", - f"{CompressionFormat.pack_quantized.value} ", - "is supported for the following bits: ", - f"{WNA16_SUPPORTED_BITS}") - - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size: int, - params_dtype: torch.dtype, **extra_weight_attrs): - - # Will transpose the loaded weight along the - # intermediate and hidden dim sizes. Will - # shard for TP along the transposed dims - extra_weight_attrs.update({ - "is_transposed": True, - "quant_method": self.strategy - }) - w13_weight = torch.nn.Parameter(torch.empty(num_experts, - hidden_size // - self.packed_factor, - 2 * intermediate_size, - dtype=torch.int32), - requires_grad=False) - layer.register_parameter("w13_weight_packed", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter(torch.empty(num_experts, - intermediate_size // - self.packed_factor, - hidden_size, - dtype=torch.int32), - requires_grad=False) - layer.register_parameter("w2_weight_packed", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - if self.strategy == "channel": - num_groups_w2 = num_groups_w13 = 1 - self.group_size = -1 - else: - num_groups_w2 = intermediate_size // self.group_size - num_groups_w13 = hidden_size // self.group_size - - w13_scale = torch.nn.Parameter(torch.ones(num_experts, - num_groups_w13, - 2 * intermediate_size, - dtype=params_dtype), - requires_grad=False) - layer.register_parameter("w13_weight_scale", w13_scale) - set_weight_attrs(w13_scale, extra_weight_attrs) - - w2_scale = torch.nn.Parameter(torch.ones(num_experts, - num_groups_w2, - hidden_size, - dtype=params_dtype), - requires_grad=False) - layer.register_parameter("w2_weight_scale", w2_scale) - set_weight_attrs(w2_scale, extra_weight_attrs) - - w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), - requires_grad=False) - layer.register_parameter("w2_weight_shape", w2_weight_shape) - set_weight_attrs(w2_weight_shape, extra_weight_attrs) - w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), - requires_grad=False) - - layer.register_parameter("w13_weight_shape", w13_weight_shape) - set_weight_attrs(w13_weight_shape, extra_weight_attrs) - - w13_g_idx = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_g_idx", w13_g_idx) - set_weight_attrs(w13_g_idx, extra_weight_attrs) - - w2_g_idx = torch.nn.Parameter( - torch.empty( - num_experts, - intermediate_size, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w2_g_idx", w2_g_idx) - set_weight_attrs(w2_g_idx, extra_weight_attrs) - - w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_g_idx_sort_indices", - w13_g_idx_sort_indices) - set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) - - w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty( - num_experts, - intermediate_size, 
- dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w2_g_idx_sort_indices", - w2_g_idx_sort_indices) - set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) - - layer.a13_scale = None - layer.a2_scale = None - layer.marlin_state = GPTQMarlinState.REPACK - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - - def replace_tensor(name, new_t): - # It is important to use resize_() here since it ensures - # the same buffer is reused - getattr(layer, name).resize_(new_t.shape) - getattr(layer, name).copy_(new_t) - del new_t - - def get_scale_perms(num_bits: int): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, - group_size: int, num_bits: int): - scale_perm, scale_perm_single = get_scale_perms(num_bits) - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, scale_perm] - else: - s = s.reshape((-1, len(scale_perm_single)))[:, - scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - return s - - def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, - size_n: int, group_size: int, - num_bits: int): - num_experts = s.shape[0] - output = torch.empty((num_experts, s.shape[1], s.shape[2]), - device=s.device, - dtype=s.dtype) - for e in range(num_experts): - output[e] = marlin_permute_scales(s[e], size_k, size_n, - group_size, num_bits) - return output - - size_k2 = layer.w2_weight_packed.shape[2] - size_k13 = layer.w13_weight_packed.shape[2] - - num_experts = layer.w13_g_idx.shape[0] - device = layer.w13_g_idx.device - layer.w13_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - layer.w2_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - layer.w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - layer.w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - - marlin_w13_qweight = ops.gptq_marlin_moe_repack( - layer.w13_weight_packed, - layer.w13_g_idx_sort_indices, - layer.w13_weight_packed.shape[1] * self.packed_factor, - layer.w13_weight_packed.shape[2], - self.num_bits, - ) - replace_tensor("w13_weight_packed", marlin_w13_qweight) - marlin_w2_qweight = ops.gptq_marlin_moe_repack( - layer.w2_weight_packed, - layer.w2_g_idx_sort_indices, - layer.w2_weight_packed.shape[1] * self.packed_factor, - layer.w2_weight_packed.shape[2], - self.num_bits, - ) - replace_tensor("w2_weight_packed", marlin_w2_qweight) - # Repack scales - marlin_w13_scales = marlin_moe_permute_scales( - layer.w13_weight_scale, - size_k13, - layer.w13_weight_scale.shape[2], - self.group_size, - self.num_bits, - ) - replace_tensor("w13_weight_scale", marlin_w13_scales) - marlin_w2_scales = marlin_moe_permute_scales( - layer.w2_weight_scale, - layer.w2_weight_scale.shape[1] * self.packed_factor, - size_k2, - self.group_size, - self.num_bits, - ) - replace_tensor("w2_weight_scale", marlin_w2_scales) - - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: 
bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None) -> torch.Tensor: - - from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_marlin_moe) - - return fused_marlin_moe(x, - layer.w13_weight_packed, - layer.w2_weight_packed, - router_logits, - layer.w13_g_idx, - layer.w2_g_idx, - layer.w13_g_idx_sort_indices, - layer.w2_g_idx_sort_indices, - top_k, - renormalize=renormalize, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 7f45a20bd9dd9..fd7682a1c0f51 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -7,8 +7,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, - FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization.base_config import ( @@ -319,16 +318,19 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, dtype=torch.float32), requires_grad=False) layer.register_parameter("w2_weight_scale", w2_weight_scale) - # Add the quantization method used (per tensor/grouped/channel) - # to ensure the weight scales are loaded in properly - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + # If loading fp8 checkpoint, pass the weight loaders. # If loading an fp16 checkpoint, do not (we will quantize in # process_weights_after_loading() if self.quant_config.is_checkpoint_fp8_serialized: - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) + set_weight_attrs(w13_weight_scale, { + "is_fp8_scale": True, + **extra_weight_attrs + }) + set_weight_attrs(w2_weight_scale, { + "is_fp8_scale": True, + **extra_weight_attrs + }) # INPUT_SCALES if self.quant_config.activation_scheme == "static": @@ -341,14 +343,19 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, num_experts, dtype=torch.float32), requires_grad=False) layer.register_parameter("w13_input_scale", w13_input_scale) - set_weight_attrs(w13_input_scale, extra_weight_attrs) + set_weight_attrs(w13_input_scale, { + "is_fp8_scale": True, + **extra_weight_attrs + }) w2_input_scale = torch.nn.Parameter(torch.ones( num_experts, dtype=torch.float32), requires_grad=False) layer.register_parameter("w2_input_scale", w2_input_scale) - set_weight_attrs(w2_input_scale, extra_weight_attrs) - + set_weight_attrs(w2_input_scale, { + "is_fp8_scale": True, + **extra_weight_attrs + }) else: layer.w13_input_scale = None layer.w2_input_scale = None diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 4bb943ab3afe4..331b859d2adec 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -23,11 +23,11 @@ def get_model_architecture( architectures = getattr(model_config.hf_config, "architectures", []) # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. 
-    mixtral_supported = ["fp8", "compressed-tensors"]
     if (model_config.quantization is not None
-            and model_config.quantization not in mixtral_supported
+            and model_config.quantization != "fp8"
             and "MixtralForCausalLM" in architectures):
         architectures = ["QuantMixtralForCausalLM"]
 
+
     return ModelRegistry.resolve_model_cls(architectures)
diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
index caeda4e42d8a0..b82eb14fb5f23 100644
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -920,7 +920,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = param.weight_loader
                 weight_loader(param,
                               loaded_weight,
-                              name,
+                              weight_name,
                               shard_id=shard_id,
                               expert_id=expert_id)
                 break
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index 413783ba4b259..34f581ac78582 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -73,7 +73,6 @@ def __init__(self,
         self.hidden_size = hidden_size
 
         # Gate always runs at half / full precision for now.
-
         self.gate = ReplicatedLinear(hidden_size,
                                      num_experts,
                                      bias=False,

From eeee1c3b1ae30a9714dffe7a58bdbed10b1e2e38 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 21 Aug 2024 21:31:49 -0700
Subject: [PATCH 16/16] [TPU] Avoid initializing TPU runtime in is_tpu (#7763)

---
 vllm/platforms/__init__.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py
index 958f6c516a2f8..aedf3c3a950ee 100644
--- a/vllm/platforms/__init__.py
+++ b/vllm/platforms/__init__.py
@@ -8,8 +8,10 @@
 
 is_tpu = False
 try:
-    import torch_xla.core.xla_model as xm
-    xm.xla_device(devkind="TPU")
+    # While it's technically possible to install libtpu on a non-TPU machine,
+    # this is a very uncommon scenario. Therefore, we assume that libtpu is
+    # installed if and only if the machine has TPUs.
+    import libtpu  # noqa: F401
     is_tpu = True
 except Exception:
     pass
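
A note on the detection pattern in the hunk above: instead of creating an XLA device (which spins up the TPU runtime), the check now treats the mere presence of libtpu as the signal that the host has TPUs. The helper below is an illustrative sketch of that idea outside vLLM, not code from this patch; the has_tpu name and the use of importlib.util.find_spec (which looks the module up without importing or initializing anything) are assumptions made for the example.

import importlib.util


def has_tpu() -> bool:
    """Best-effort TPU detection that never touches the TPU runtime.

    Assumes libtpu is installed if and only if the machine has TPUs,
    mirroring the reasoning in the comment added by the patch.
    """
    # find_spec() only consults the import machinery; it does not execute
    # the module, so no runtime is initialized as a side effect.
    return importlib.util.find_spec("libtpu") is not None


if __name__ == "__main__":
    print("TPU available:", has_tpu())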