diff --git a/examples/basic-ub.py b/examples/basic-ub.py new file mode 100644 index 000000000000..3f6fd2fdb82e --- /dev/null +++ b/examples/basic-ub.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: Apache-2.0 + +import logging +import os + +from vllm import LLM, SamplingParams + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Configure logging level for vllm (optional, uses VLLM_LOGGING_LEVEL env var). +logging_level = os.getenv("VLLM_LOGGING_LEVEL", "").upper() +if logging_level: + logging.basicConfig(level=getattr(logging, logging_level, logging.INFO)) + +# Create a sampling params object, optionally limiting output tokens via MAX_TOKENS env var. +param_kwargs = {"temperature": 0.8, "top_p": 0.95} +max_tokens_env = os.getenv("MAX_TOKENS") +if max_tokens_env is not None: + try: + param_kwargs["max_tokens"] = int(max_tokens_env) + except ValueError: + raise ValueError(f"Invalid MAX_TOKENS value: {max_tokens_env}") +sampling_params = SamplingParams(**param_kwargs) + + +def main(): + # Create an LLM. + model = "deepseek-ai/DeepSeek-V2-Lite" + # model = "facebook/opt-125m" + llm = LLM(model=model, + enforce_eager=True, + compilation_config=2, + ############### + trust_remote_code=True, + max_model_len=1024, + #load_format="dummy", + ############### + #tensor_parallel_size=1, + data_parallel_size=2, + enable_expert_parallel=True, + ############### + #enable_microbatching=True, + ) + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. + print("\nGenerated Outputs:\n" + "-" * 60) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Output: {generated_text!r}") + print("-" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index dbf8ed58cc47..9ea1b9997ad7 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -32,66 +32,43 @@ import os from time import sleep -from vllm import LLM, SamplingParams -from vllm.utils import get_open_port +from vllm import LLM, EngineArgs, SamplingParams +from vllm.utils import FlexibleArgumentParser, get_open_port def parse_args(): - import argparse - - parser = argparse.ArgumentParser(description="Data Parallel Inference") - parser.add_argument( - "--model", - type=str, - default="ibm-research/PowerMoE-3b", - help="Model name or path", - ) - parser.add_argument("--dp-size", type=int, default=2, help="Data parallel size") - parser.add_argument("--tp-size", type=int, default=2, help="Tensor parallel size") - parser.add_argument( - "--node-size", type=int, default=1, help="Total number of nodes" - ) - parser.add_argument( - "--node-rank", type=int, default=0, help="Rank of the current node" - ) - parser.add_argument( - "--master-addr", type=str, default="", help="Master node IP address" - ) - parser.add_argument("--master-port", type=int, default=0, help="Master node port") - parser.add_argument( - "--enforce-eager", action="store_true", help="Enforce eager mode execution." - ) - parser.add_argument( - "--trust-remote-code", action="store_true", help="Trust remote code." 
- ) - parser.add_argument( - "--max-num-seqs", - type=int, - default=64, - help=("Maximum number of sequences to be processed in a single iteration."), - ) - parser.add_argument( - "--gpu-memory-utilization", - type=float, - default=0.8, - help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."), - ) + parser = FlexibleArgumentParser() + EngineArgs.add_cli_args(parser) + parser.set_defaults(model="ibm-research/PowerMoE-3b") + parser.add_argument("--dp-size", + type=int, + default=2, + help="Data parallel size") + parser.add_argument("--tp-size", + type=int, + default=2, + help="Tensor parallel size") + parser.add_argument("--node-size", + type=int, + default=1, + help="Total number of nodes") + parser.add_argument("--node-rank", + type=int, + default=0, + help="Rank of the current node") + parser.add_argument("--master-addr", + type=str, + default="", + help="Master node IP address") + parser.add_argument("--master-port", + type=int, + default=0, + help="Master node port") return parser.parse_args() -def main( - model, - dp_size, - local_dp_rank, - global_dp_rank, - dp_master_ip, - dp_master_port, - GPUs_per_dp_rank, - enforce_eager, - trust_remote_code, - max_num_seqs, - gpu_memory_utilization, -): +def main(args, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, + dp_master_port, GPUs_per_dp_rank): os.environ["VLLM_DP_RANK"] = str(global_dp_rank) os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank) os.environ["VLLM_DP_SIZE"] = str(dp_size) @@ -107,7 +84,11 @@ def main( "The president of the United States is", "The capital of France is", "The future of AI is", - ] * 100 + ] * 10 + # import random + # import string + # prompts = [''.join(random.choices(string.ascii_letters, k=128)) for _ in range(2048)] + # with DP, each rank should process different prompts. # usually all the DP ranks process a full dataset, @@ -131,18 +112,18 @@ def start(rank): # sampling params. here we set different max_tokens for different # ranks for demonstration. sampling_params = SamplingParams( - temperature=0.8, top_p=0.95, max_tokens=[16, 20][global_dp_rank % 2] + temperature=0.8, top_p=0.95, max_tokens=[20, 16][global_dp_rank % 2] ) + # Fixed params + args.pop("tensor_parallel_size") + args.pop("enable_expert_parallel") + # Create an LLM. llm = LLM( - model=model, tensor_parallel_size=GPUs_per_dp_rank, - enforce_eager=enforce_eager, enable_expert_parallel=True, - trust_remote_code=trust_remote_code, - max_num_seqs=max_num_seqs, - gpu_memory_utilization=gpu_memory_utilization, + **args, ) outputs = llm.generate(prompts, sampling_params) # Print the outputs. 
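# The data-parallel example above assumes each DP rank processes a different
# slice of the prompt list. A minimal sketch of such a partitioning helper is
# shown below; the function name and the contiguous-chunk policy are
# illustrative assumptions, not code from this patch.
def partition_prompts(prompts: list[str], dp_size: int,
                      global_dp_rank: int) -> list[str]:
    # Contiguous split; earlier ranks absorb the remainder so every prompt
    # is assigned to exactly one rank.
    floor, remainder = divmod(len(prompts), dp_size)
    start = global_dp_rank * floor + min(global_dp_rank, remainder)
    end = start + floor + (1 if global_dp_rank < remainder else 0)
    return prompts[start:end]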
@@ -162,19 +143,22 @@ def start(rank): if __name__ == "__main__": - args = parse_args() - dp_size = args.dp_size - tp_size = args.tp_size - node_size = args.node_size - node_rank = args.node_rank + args = vars(parse_args()) + + dp_size = args.pop("dp_size") + tp_size = args.pop("tp_size") + node_size = args.pop("node_size") + node_rank = args.pop("node_rank") if node_size == 1: dp_master_ip = "127.0.0.1" dp_master_port = get_open_port() + args.pop("master_addr") + args.pop("master_port") else: - dp_master_ip = args.master_addr - dp_master_port = args.master_port + dp_master_ip = args.pop("master_addr") + dp_master_port = args.pop("master_port") assert dp_size % node_size == 0, "dp_size should be divisible by node_size" dp_per_node = dp_size // node_size @@ -183,29 +167,22 @@ def start(rank): procs = [] for local_dp_rank, global_dp_rank in enumerate( - range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node) - ): - proc = Process( - target=main, - args=( - args.model, - dp_size, - local_dp_rank, - global_dp_rank, - dp_master_ip, - dp_master_port, - tp_size, - args.enforce_eager, - args.trust_remote_code, - args.max_num_seqs, - args.gpu_memory_utilization, - ), - ) + range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)): + proc = Process(target=main, + args=( + args, + dp_size, + local_dp_rank, + global_dp_rank, + dp_master_ip, + dp_master_port, + tp_size, + )) proc.start() procs.append(proc) exit_code = 0 for proc in procs: - proc.join(timeout=300) + proc.join(timeout=1200) if proc.exitcode is None: print(f"Killing process {proc.pid} that didn't stop within 5 minutes.") proc.kill() diff --git a/vllm/compilation/cuda_piecewise_backend.py b/vllm/compilation/cuda_piecewise_backend.py index 8c49ea6cc107..18c3dfe0f171 100644 --- a/vllm/compilation/cuda_piecewise_backend.py +++ b/vllm/compilation/cuda_piecewise_backend.py @@ -106,6 +106,7 @@ def check_for_ending_compilation(self): end_monitoring_torch_compile(self.vllm_config) def __call__(self, *args) -> Any: + # logger.info("CUDA BACKEND CALL") if not self.first_run_finished: self.first_run_finished = True self.check_for_ending_compilation() diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 05e4ca9f08b3..54d5af2ad29d 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -157,6 +157,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): vllm_config.compilation_config.level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS ] or not supports_dynamo() + self.do_not_compile = True if self.do_not_compile: return compilation_counter.num_models_seen += 1 @@ -170,6 +171,7 @@ def __call__(self, *args, **kwargs): # e.g. TPU has the compilation logic in model runner, so we don't # need to compile the model inside. if self.do_not_compile or torch.compiler.is_compiling(): + # logger.info("SKIPPING COMPILATION") return self.forward(*args, **kwargs) # the first compilation needs to have dynamic shapes marked diff --git a/vllm/config.py b/vllm/config.py index 46a5bf34f66e..f6e416e5ad65 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1816,6 +1816,18 @@ class ParallelConfig: disable_custom_all_reduce: bool = False """Disable the custom all-reduce kernel and fall back to NCCL.""" + enable_microbatching: bool = False + """Enable microbatching for the model executor.""" + + always_microbatch_if_enabled: bool = True + """Always microbatch if microbatching is enabled. 
Easier to sync between
+    dp workers."""
+
+    microbatching_token_threshold: int = 4
+    """The threshold for microbatching. If the number of tokens in the
+    request is greater than this threshold, microbatching will be used.
+    Otherwise, the request will be processed in a single batch."""
+
     tokenizer_pool_config: Optional[TokenizerPoolConfig] = None
     """This parameter is deprecated and will be removed in a future release.
     Please remove it from your configs"""
@@ -4564,6 +4576,20 @@ def __post_init__(self):
                     "cascade attention. Disabling cascade attention.")
                 self.model_config.disable_cascade_attn = True
 
+        if self.parallel_config.enable_microbatching and \
+                self.compilation_config.level >= CompilationLevel.PIECEWISE:
+            # Microbatching is not supported with piecewise compilation yet,
+            # more specifically with piecewise CUDA graphs.
+            logger.warning_once(
+                "Piecewise compilation is not supported with "
+                "microbatching. Disabling piecewise compilation.")
+            self.compilation_config.level = CompilationLevel.NO_COMPILATION
+            if not self.model_config.enforce_eager:
+                self.compilation_config.full_cuda_graph = True
+                logger.warning_once(
+                    "Enabling full CUDA graphs for microbatching"
+                )
+
         disable_chunked_prefill_reasons: list[str] = []
 
         if self.model_config and self.model_config.pooler_config:
diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py
index 85f87cb21edc..f64ff0014b2d 100644
--- a/vllm/distributed/device_communicators/all2all.py
+++ b/vllm/distributed/device_communicators/all2all.py
@@ -101,13 +101,23 @@ def __init__(self, cpu_group):
             logger.debug("PPLX NVSHMEM UID = %s", uid)
             nvshmem_init(uid, self.rank, self.world_size)
 
-        self.handle_cache = Cache()
+        # self.handle_cache = Cache()
+        self.handle_caches = [Cache(), Cache()]
 
     def get_handle(self, kwargs):
         import pplx_kernels as pplx
-        return self.handle_cache.get_or_create(
+        return self.handle_caches[0].get_or_create(
             kwargs, pplx.AllToAll.internode
             if self.internode else pplx.AllToAll.intranode)
+
+    def get_handles(self, kwargs):
+        import pplx_kernels as pplx
+        first_handle = self.handle_caches[0].get_or_create(kwargs, pplx.AllToAll.internode
+            if self.internode else pplx.AllToAll.intranode)
+        second_handle = self.handle_caches[1].get_or_create(kwargs, pplx.AllToAll.internode
+            if self.internode else pplx.AllToAll.intranode)
+        return [first_handle, second_handle]
 
     def dispatch(self, hidden_states: torch.Tensor,
                  router_logits: torch.Tensor):
@@ -117,9 +127,10 @@ def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError
 
     def destroy(self):
-        with self.handle_cache._lock:
-            for _, handle in self.handle_cache._cache.items():
-                handle.destroy()
+        for handle_cache in self.handle_caches:
+            with handle_cache._lock:
+                for _, handle in handle_cache._cache.items():
+                    handle.destroy()
 
         if self.internode:
             from pplx_kernels.nvshmem import nvshmem_finalize
@@ -136,7 +147,7 @@ def __init__(self, cpu_group):
         assert has_deep_ep(
         ), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels."  # noqa
         super().__init__(cpu_group)
-        self.handle_cache = Cache()
+        self.handle_caches = [Cache(), Cache()]
 
         # This is the DeepEP default. Stick to it till we can establish
         # reasonable defaults based on profiling.
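# The switch from a single `handle_cache` to `handle_caches = [Cache(), Cache()]`
# above gives each microbatch its own all2all handle so both microbatches can be
# in flight at once. A minimal sketch of the slot-selection convention (assuming
# `ubatch_id == -1` means "not microbatching", as elsewhere in this patch):
def select_a2a_slot(handles: list, ubatch_id: int = -1):
    assert len(handles) == 2, "one handle per microbatch slot"
    return handles[0] if ubatch_id < 0 else handles[ubatch_id]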
@@ -163,6 +174,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): def __init__(self, cpu_group): super().__init__(cpu_group) + self.handle_cache = self.handle_caches[0] def _make_all2all_kwargs(self) -> dict[Any, Any]: # Defaults for internode and intranode are taken from DeepEP tests. @@ -254,7 +266,7 @@ def get_handle(self, kwargs): import deep_ep buffer_kwargs = self._make_all2all_kwargs(**kwargs) logger.debug("DeepEP all2all args %s", buffer_kwargs) - handle: deep_ep.Buffer = self.handle_cache.get_or_create( + handle: deep_ep.Buffer = self.handle_caches[0].get_or_create( buffer_kwargs, deep_ep.Buffer) # It is dangerous to set num sms outside this function. num_sms is not # a part of the hash-key that identifies this object. If we are in a @@ -262,3 +274,10 @@ def get_handle(self, kwargs): # in get_or_create must be updated. handle.set_num_sms(self.num_sms) return handle + + def get_handles(self, kwargs): + import deep_ep + buffer_kwargs = self._make_all2all_kwargs(**kwargs) + first_handle = self.handle_caches[0].get_or_create(buffer_kwargs, deep_ep.Buffer) + second_handle = self.handle_caches[1].get_or_create(buffer_kwargs, deep_ep.Buffer) + return [first_handle, second_handle] diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 2d3783363c00..673e7da7f4c0 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -323,6 +323,7 @@ class EngineArgs: data_parallel_rpc_port: Optional[int] = None data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel + enable_microbatching: bool = ParallelConfig.enable_microbatching enable_eplb: bool = ParallelConfig.enable_eplb num_redundant_experts: int = ParallelConfig.num_redundant_experts eplb_window_size: int = ParallelConfig.eplb_window_size @@ -674,6 +675,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parallel_group.add_argument( "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]) + parallel_group.add_argument("--enable-microbatching", + **parallel_kwargs["enable_microbatching"]) parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"]) parallel_group.add_argument("--num-redundant-experts", @@ -1153,6 +1156,7 @@ def create_engine_config( data_parallel_rpc_port=data_parallel_rpc_port, data_parallel_backend=data_parallel_backend, enable_expert_parallel=self.enable_expert_parallel, + enable_microbatching=self.enable_microbatching, enable_eplb=self.enable_eplb, num_redundant_experts=self.num_redundant_experts, eplb_window_size=self.eplb_window_size, diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py index f3aee188dae9..69f8d022619b 100644 --- a/vllm/entrypoints/logger.py +++ b/vllm/entrypoints/logger.py @@ -40,11 +40,11 @@ def log_inputs( if prompt_token_ids is not None: prompt_token_ids = prompt_token_ids[:max_log_len] - logger.info( - "Received request %s: prompt: %r, " - "params: %s, prompt_token_ids: %s, " - "prompt_embeds shape: %s, " - "lora_request: %s, prompt_adapter_request: %s.", request_id, - prompt, params, prompt_token_ids, - prompt_embeds.shape if prompt_embeds is not None else None, - lora_request, prompt_adapter_request) + # logger.info( + # "Received request %s: prompt: %r, " + # "params: %s, prompt_token_ids: %s, " + # "prompt_embeds shape: %s, " + # "lora_request: %s, prompt_adapter_request: %s.", request_id, + # prompt, params, prompt_token_ids, + # prompt_embeds.shape if prompt_embeds is not None else None, + # 
lora_request, prompt_adapter_request) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index dd55b19feeaf..4671db497113 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -44,9 +44,31 @@ def num_tokens_across_dp(num_tokens: int, dp_size: int, device="cpu", dtype=torch.int32) from vllm.distributed.parallel_state import get_dp_group + # logger.info("STARTING AR num_tokens_across_dp") dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group) + # logger.info("finishing num_tokens_across_dp") return num_tokens_tensor + @staticmethod + def should_ubatch_across_dp(should_ubatch: bool, dp_size: int, dp_rank: int) -> bool: + should_ubatch_across_dp = [0] * dp_size + should_ubatch_across_dp[dp_rank] = 1 if should_ubatch else 0 + should_ubatch_tensor = torch.tensor(should_ubatch_across_dp, + device="cpu", + dtype=torch.int32) + from vllm.distributed.parallel_state import get_dp_group + # logger.info(f"should_ubatch_tensor before ar {should_ubatch_tensor}") + dist.all_reduce(should_ubatch_tensor, group=get_dp_group().cpu_group) + # logger.info(f"should_ubatch_tensor after ar {should_ubatch_tensor}") + + # If there's an incorrect ordering of ARs across DP ranks, this tensor + # can end up containing the number of padded tokens for a DP rank + assert torch.all(should_ubatch_tensor <= 1) + + result: bool = bool(torch.all(should_ubatch_tensor == 1).item()) + # print(f"FINISHING AR should_ubatch_across_dp {result} {should_ubatch_tensor}") + return result + @staticmethod def make( parallel_config: ParallelConfig, @@ -69,6 +91,7 @@ def make( # If num_tokens_across_dp is None, it will be computed by all_reduce # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize + # print(f"num_tokens_across_dp {num_tokens_across_dp} batchsize {batchsize}") assert (num_tokens_across_dp is None or num_tokens_across_dp[dp_rank] == batchsize) if num_tokens_across_dp is None: @@ -108,6 +131,42 @@ def get_forward_context() -> ForwardContext: return _forward_context +def create_forward_context(attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0, + num_tokens: Optional[int] = None, + num_tokens_across_dp: Optional[torch.Tensor] = None, + skip_cuda_graphs: bool = False): + dp_metadata: Optional[DPMetadata] = None + if vllm_config.parallel_config.data_parallel_size > 1 and ( + attn_metadata is not None or num_tokens is not None): + dp_metadata = DPMetadata.make(vllm_config.parallel_config, + attn_metadata, num_tokens or 0, + num_tokens_across_dp) + + return ForwardContext(no_compile_layers=vllm_config.compilation_config. + static_forward_context, + virtual_engine=virtual_engine, + attn_metadata=attn_metadata, + dp_metadata=dp_metadata, + skip_cuda_graphs=skip_cuda_graphs) + + +@contextmanager +def override_forward_context(forward_context: Optional[ForwardContext]): + """A context manager that overrides the current forward context. + This is used to override the forward context for a specific + forward pass. 
+ """ + global _forward_context + prev_context = _forward_context + _forward_context = forward_context + try: + yield + finally: + _forward_context = prev_context + + @contextmanager def set_forward_context( attn_metadata: Any, @@ -125,26 +184,15 @@ def set_forward_context( need_to_track_batchsize = track_batchsize and attn_metadata is not None if need_to_track_batchsize: forward_start_time = time.perf_counter() - dp_metadata: Optional[DPMetadata] = None - if vllm_config.parallel_config.data_parallel_size > 1 and ( - attn_metadata is not None or num_tokens is not None): - dp_metadata = DPMetadata.make(vllm_config.parallel_config, - attn_metadata, num_tokens or 0, - num_tokens_across_dp) - global _forward_context - prev_context = _forward_context - _forward_context = ForwardContext( - no_compile_layers=vllm_config.compilation_config. - static_forward_context, - virtual_engine=virtual_engine, - attn_metadata=attn_metadata, - dp_metadata=dp_metadata, - skip_cuda_graphs=skip_cuda_graphs, - ) + forward_context = create_forward_context(attn_metadata, vllm_config, + virtual_engine, num_tokens, + num_tokens_across_dp, + skip_cuda_graphs) try: - yield + with override_forward_context(forward_context): + yield finally: global last_logging_time, batchsize_logging_interval if need_to_track_batchsize: @@ -181,5 +229,3 @@ def set_forward_context( logger.info(("Batchsize forward time stats " "(batchsize, count, median_time(ms)): %s"), forward_stats) - - _forward_context = prev_context diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 5a8accd80463..d98e60779579 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -7,6 +7,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) +from vllm.v1.worker.ubatching import ( + get_current_ubatch_context, yield_and_switch_from_comm_to_compute_impl, + yield_and_switch_from_compute_to_comm_impl) # DeepEP kernels quantize dispatch inputs in 128 element chunks. DEEPEP_QUANT_BLOCK_SIZE = 128 @@ -38,7 +41,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): SUPPORTED_HIDDEN_SIZES = [2560, 4096, 5120, 7168] def __init__(self, - buffer: deep_ep.Buffer, + buffers: list[deep_ep.Buffer], world_size: int, dp_size: int, max_tokens_per_rank: int, @@ -47,7 +50,7 @@ def __init__(self, use_fp8_dispatch: bool = False): super().__init__() - self.buffer = buffer + self.buffers = buffers self.world_size = world_size self.dp_size = dp_size self.quant_dtype = quant_dtype @@ -57,7 +60,7 @@ def __init__(self, # The dispatch function returns a handle that the combine function # requires. We store the handle here so it is available to the # combine function. 
- self.handle = None + self.handles: list[Optional[tuple]] = [None, None] def max_num_tokens_per_rank(self) -> Optional[int]: return self.max_tokens_per_rank @@ -127,9 +130,12 @@ def prepare( Optional[torch.Tensor], Optional[torch.Tensor]]: hidden_size = a1.size(1) - assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \ - (f"Hidden Size {hidden_size} not in supported list of hidden sizes" - f"{self.SUPPORTED_HIDDEN_SIZES}") + ubatch_ctx = get_current_ubatch_context() + ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1 + a2a_idx = 0 if ubatch_id == -1 else ubatch_id + # assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \ + # (f"Hidden Size {hidden_size} not in supported list of hidden sizes" + # f"{self.SUPPORTED_HIDDEN_SIZES}") if self.use_fp8_dispatch: assert hidden_size % 128 == 0, \ @@ -149,14 +155,17 @@ def prepare( a1 = a1 * rank_topk_weights.to(a1.dtype) # Dispatch - expert_x, expert_num_tokens, self.handle, event, hook = \ - self.buffer.low_latency_dispatch(a1, + # yield_and_switch_from_compute_to_comm_impl(schedule="default") + expert_x, expert_num_tokens, handle, event, hook = \ + self.buffers[a2a_idx].low_latency_dispatch(a1, rank_topk_ids, self.max_tokens_per_rank, num_experts, use_fp8=self.use_fp8_dispatch, async_finish=False, return_recv_hook=False) + self.handles[a2a_idx] = handle + # yield_and_switch_from_comm_to_compute_impl(schedule="default") expert_x, expert_x_scale = self._do_quant(expert_x, a1_scale, a2_scale, a1.dtype) @@ -167,7 +176,11 @@ def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, apply_router_weight_on_input: bool) -> None: - assert self.handle is not None + ubatch_ctx = get_current_ubatch_context() + ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1 + a2a_idx = 0 if ubatch_id == -1 else ubatch_id + handle = self.handles[a2a_idx] + assert handle is not None combine_topk_weights = topk_weights if apply_router_weight_on_input: @@ -175,12 +188,16 @@ def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, combine_topk_weights = torch.ones_like(topk_weights) # TODO (varun) : Enable zero copy mode - _, event, hook = self.buffer.low_latency_combine( + # yield_and_switch_from_compute_to_comm_impl(schedule="default") + _, event, hook = self.buffers[a2a_idx].low_latency_combine( fused_expert_output, topk_ids, combine_topk_weights, - self.handle, + handle, async_finish=False, zero_copy=False, return_recv_hook=False, out=output) + # event.current_stream_wait() + # yield_and_switch_from_comm_to_compute_impl(schedule="default") + diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f22884b8a1a5..3e9ef23b6ef2 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -903,7 +903,7 @@ def fused_topk( # This is used by the Deepseek-V2 and Deepseek-V3 model -@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) +# @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index e6f555d315d8..e1d184d89e95 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -32,6 +32,8 @@ from vllm.platforms import current_platform from vllm.platforms.interface import 
CpuArchEnum from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx +from vllm.v1.worker.ubatching import get_current_ubatch_context + if current_platform.is_cuda_alike(): from .fused_batched_moe import BatchedTritonExperts @@ -333,13 +335,13 @@ def init_prepare_finalize(self, moe: MoEConfig, all_to_all_args[ "group_name"] = all2all_manager.cpu_group.group_name - handle = all2all_manager.get_handle(all_to_all_args) + handles = all2all_manager.get_handles(all_to_all_args) input_activations = get_quant_config_input_activations( quant_config) prepare_finalize = PplxPrepareAndFinalize( - handle, + handles, max_num_tokens=moe.max_num_tokens, world_size=all2all_manager.world_size, rank=all2all_manager.rank, @@ -376,11 +378,11 @@ def init_prepare_finalize(self, moe: MoEConfig, num_global_experts=moe.num_experts, num_local_experts=moe.num_experts // all2all_manager.world_size) - handle = all2all_manager.get_handle(all_to_all_args) + handles = all2all_manager.get_handles(all_to_all_args) # Note : We may want to use FP8 dispatch even otherwise just to # reduce datamovement - assert act_quant_block_size is not None + #assert act_quant_block_size is not None use_fp8_dispatch = (quant_dtype == current_platform.fp8_dtype() and act_quant_block_size[1] == DEEPEP_QUANT_BLOCK_SIZE) @@ -388,7 +390,7 @@ def init_prepare_finalize(self, moe: MoEConfig, # Note (varun): Whether to use FP8 dispatch or not needs some # profiling. Turning it off for now. prepare_finalize = DeepEPLLPrepareAndFinalize( - handle, + handles, world_size=all2all_manager.world_size, dp_size=all2all_manager.dp_world_size, max_tokens_per_rank=moe.max_num_tokens, @@ -1000,13 +1002,13 @@ def __init__( or self.moe_parallel_config.use_deepep_ll_kernels): act_dtype = vllm_config.model_config.dtype self.batched_hidden_states = torch.zeros( - (envs.VLLM_MOE_DP_CHUNK_SIZE, self.hidden_size), + (2, envs.VLLM_MOE_DP_CHUNK_SIZE, self.hidden_size), dtype=act_dtype, device=torch.cuda.current_device()) # Note here we use `num_experts` which is logical expert count self.batched_router_logits = torch.zeros( - (envs.VLLM_MOE_DP_CHUNK_SIZE, num_experts), + (2, envs.VLLM_MOE_DP_CHUNK_SIZE, num_experts), dtype=act_dtype, device=torch.cuda.current_device()) @@ -1566,15 +1568,19 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): chunk_size = chunk_end - chunk_start hidden_states = full_hidden_states[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :] - - assert (self.batched_hidden_states.size(0) # type: ignore + + ubatch_ctx = get_current_ubatch_context() + ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1 + batch_buffer_idx = 0 if ubatch_id == -1 else ubatch_id + batched_hidden_states = self.batched_hidden_states[batch_buffer_idx, :] + batched_router_logits = self.batched_router_logits[batch_buffer_idx, :] + + assert (batched_hidden_states.size(0) # type: ignore >= chunk_size) - assert (self.batched_router_logits.size(0) # type: ignore + assert (batched_router_logits.size(0) # type: ignore >= chunk_size) - staged_hidden_states = self.batched_hidden_states[: - chunk_size, :] # type: ignore - staged_router_logits = self.batched_router_logits[: - chunk_size, :] # type: ignore + staged_hidden_states = batched_hidden_states[:chunk_size, :] # type: ignore + staged_router_logits = batched_router_logits[:chunk_size, :] # type: ignore staged_hidden_states.copy_(hidden_states, non_blocking=True) staged_router_logits.copy_(router_logits, non_blocking=True) diff --git 
a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 2ff8ef99b2ec..6e577cfd9e04 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -8,6 +8,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) +from vllm.v1.worker.ubatching import ( + get_current_ubatch_context, yield_and_switch_from_comm_to_compute_impl, + yield_and_switch_from_compute_to_comm_impl) # The max_num_tokens, world_size and dp_size must be the same @@ -15,7 +18,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def __init__(self, - a2a: pplx.AllToAll, + a2as: list[pplx.AllToAll], max_num_tokens: int, world_size: int, rank: int, @@ -25,7 +28,7 @@ def __init__(self, per_act_token: bool = False): super().__init__() assert max_num_tokens > 0 - self.a2a = a2a + self.a2as = a2as self.block_shape = block_shape self.max_num_tokens = max_num_tokens self.world_size = world_size @@ -54,6 +57,9 @@ def prepare( Optional[torch.Tensor], Optional[torch.Tensor]]: num_tokens = a1.size(0) # M hidden_dim = a1.size(-1) # K + ubatch_ctx = get_current_ubatch_context() + ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1 + a2a_idx = 0 if ubatch_id == -1 else ubatch_id assert rank_topk_ids.size(0) == num_tokens # assert expert_map is None, "NYI" @@ -115,15 +121,28 @@ def prepare( # There's not much point setting this unless it is != indices.size(0) bound_m: Optional[torch.Tensor] = None - self.a2a.dispatch( - out_expert_num_tokens=expert_num_tokens, - out_expert_x=expert_x, - out_expert_x_scale=expert_x_scale, - dp_x=a1q, - dp_x_scale=a1q_scale, - indices=rank_topk_ids, - bound_m=bound_m, - ) + def dispatch(send: bool): + self.a2as[a2a_idx].dispatch( + out_expert_num_tokens=expert_num_tokens, + out_expert_x=expert_x, + out_expert_x_scale=expert_x_scale, + dp_x=a1q, + dp_x_scale=a1q_scale, + indices=rank_topk_ids, + bound_m=bound_m, + do_send=send, + do_recv=not send, + ) + + yield_and_switch_from_compute_to_comm_impl(schedule="default") + dispatch(True) # Send + # torch.cuda.synchronize() + # print(f"{ubatch_id} AFTER SEND SYNC", flush=True) + dispatch(False) # Recv + # torch.cuda.synchronize() + # print(f"{ubatch_id} AFTER RECV SYNC", flush=True) + yield_and_switch_from_comm_to_compute_impl(schedule="default") + # torch.cuda.synchronize() if expert_x_scale is not None: expert_x_scale = expert_x_scale[:, :, 0:1] @@ -141,6 +160,9 @@ def finalize( # This argument is optional # There's not much point setting this unless it is != topk_ids.size(0) bound_m: Optional[torch.Tensor] = None + ubatch_ctx = get_current_ubatch_context() + ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1 + a2a_idx = 0 if ubatch_id == -1 else ubatch_id assert topk_ids.size(0) == num_tokens, ( f"{topk_ids.size(0)} == {num_tokens}") @@ -152,8 +174,22 @@ def finalize( if apply_router_weight_on_input: topk_weights = torch.ones_like(topk_weights) - self.a2a.combine(out_tokens=output, - indices=topk_ids, - weights=topk_weights, - expert_y=fused_expert_output, - bound_m=bound_m) + def combine(send: bool): + self.a2as[a2a_idx].combine( + out_tokens=output, + indices=topk_ids, + weights=topk_weights, + expert_y=fused_expert_output, + bound_m=bound_m, + do_send=send, + do_recv=not send, + ) + + yield_and_switch_from_compute_to_comm_impl(schedule="default") + combine(True) + # 
torch.cuda.synchronize() + # print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True) + combine(False) + # print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True) + yield_and_switch_from_comm_to_compute_impl(schedule="default") + # torch.cuda.synchronize() diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 527b31153410..422e7f7daada 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -34,6 +34,8 @@ if TYPE_CHECKING: from vllm.v1.worker.gpu_model_runner import GPUModelRunner +from vllm.v1.attention.backends.utils import slice_query_start_locs + logger = init_logger(__name__) @@ -174,26 +176,18 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, def build( self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata + common_attn_metadata: CommonAttentionMetadata, + ubatch_id: Optional[int] = None, ) -> FlashAttentionMetadata: num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) + max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens - block_table = self.block_table - block_table_tensor = block_table.get_device_tensor()[:num_reqs] - - block_table.slot_mapping[:num_actual_tokens].copy_( - block_table.slot_mapping_cpu[:num_actual_tokens], - non_blocking=True) - # Fill unused with -1. Needed for reshape_and_cache in full cuda graph - # mode. - block_table.slot_mapping[num_actual_tokens:].fill_(-1) - - slot_mapping = block_table.slot_mapping[:num_actual_tokens] + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping if self.aot_sliding_window is None: self.aot_sliding_window = (-1, -1) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 03a2ed7139c7..a057ea32053b 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -426,9 +426,7 @@ def build(self, common_prefix_len: int, device = self.runner.device qo_indptr = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens - block_table_tensor = self.block_table.get_device_tensor()[:num_reqs] - slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to( - self.runner.device, non_blocking=True).long() + block_table_tensor = common_attn_metadata.block_table_tensor block_table_bounds = (seq_lens + page_size - 1) // page_size diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index dd8d7994ed33..31236361c8e1 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -279,16 +279,11 @@ def build(self, common_prefix_len: int, num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = self.runner.seq_lens_np[:num_reqs].max() + max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens - - block_table = self.block_table - block_table_tensor = block_table.get_device_tensor()[:num_reqs] - block_table.slot_mapping[:num_actual_tokens].copy_( - block_table.slot_mapping_cpu[:num_actual_tokens], - 
non_blocking=True) - slot_mapping = block_table.slot_mapping[:num_actual_tokens] + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping use_cascade = common_prefix_len > 0 cu_prefix_query_lens = None diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 1878ae74dbc6..e5acc52de5ad 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -434,23 +434,41 @@ def reorder_batch(self, input_batch: "InputBatch", input_batch.swap_states(prefills[i - 1], decode_idx) modified_batch = True - # Save for next `build` call - # TODO(lucas): this is a bit of a hack, we should probably have a - # better way of doing this - self._num_decodes = num_decodes - self._num_prefills = num_prefills - self._num_decode_tokens = num_decode_tokens - self._num_prefill_tokens = num_prefill_tokens - return modified_batch def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens: torch.Tensor): + seq_lens: torch.Tensor, + ubatch_id: Optional[int] = None): return MLACommonDecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens, ) + + def _split_decodes_and_prefills(self, max_query_len: int, num_reqs: int, + num_tokens: int, + query_start_loc: torch.Tensor): + """ + return + - num_decodes: number of decode requests + - num_prefills: number of prefill requests + - num_decode_tokens: number of decode tokens + - num_prefill_tokens: number of prefill tokens + """ + if max_query_len == 1: + # Pure decode + return num_reqs, 0, num_tokens, 0 + else: + query_lens = query_start_loc[1:] - query_start_loc[:-1] + first_prefill = (query_lens > 1).int().argmax(dim=-1).item() + assert torch.all(query_lens[first_prefill:] > 1) + num_decodes = first_prefill + num_prefills = num_reqs - num_decodes + num_decode_tokens = first_prefill + num_prefill_tokens = num_tokens - query_start_loc[first_prefill] + return (num_decodes, num_prefills, num_decode_tokens, + num_prefill_tokens) + def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata) -> M: """ @@ -464,49 +482,49 @@ def build_for_cudagraph_capture( m.max_query_len = 1 # decode-only - # Update state usually set in reorder_batch. - self._num_decodes = m.num_reqs - self._num_decode_tokens = m.num_actual_tokens - self._num_prefills = 0 - self._num_prefill_tokens = 0 return self.build(0, m) def build(self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata) -> M: + common_attn_metadata: CommonAttentionMetadata, + ubatch_id: Optional[int] = None) -> M: num_reqs = common_attn_metadata.num_reqs - num_actual_tokens = common_attn_metadata.num_actual_tokens + num_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - assert self._num_decodes + self._num_prefills == num_reqs # Note(simon): be careful about the CPU <> GPU memory movement in this # function. We should avoid GPU -> CPU sync as much as possible because # it blocks on all previous kernels. 
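# A rough, self-contained check (not part of this patch) of the decode/prefill
# split that `_split_decodes_and_prefills` above computes, assuming the batch
# has been reordered so decodes come first and at least one prefill exists:
import torch

query_start_loc = torch.tensor([0, 1, 2, 5, 9])            # query lens [1, 1, 3, 4]
query_lens = query_start_loc[1:] - query_start_loc[:-1]
first_prefill = int((query_lens > 1).int().argmax())       # -> 2
num_decodes, num_prefills = first_prefill, query_lens.numel() - first_prefill
num_decode_tokens = first_prefill                           # one token per decode
num_prefill_tokens = int(query_start_loc[-1] - query_start_loc[first_prefill])
assert (num_decodes, num_prefills) == (2, 2)
assert (num_decode_tokens, num_prefill_tokens) == (2, 7)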
device = self.runner.device - block_table = self.block_table - block_table_tensor = block_table.get_device_tensor()[:num_reqs] - block_table.slot_mapping[:num_actual_tokens].copy_( - block_table.slot_mapping_cpu[:num_actual_tokens], - non_blocking=True) - block_table.slot_mapping[num_actual_tokens:].fill_(-1) - slot_mapping = block_table.slot_mapping[:num_actual_tokens] + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu seq_lens = common_attn_metadata.seq_lens + + query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + num_computed_tokens_cpu = common_attn_metadata.seq_lens_cpu - query_seq_lens_cpu + + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ + self._split_decodes_and_prefills( + max_query_len, num_reqs, num_tokens, query_start_loc) + + assert num_decodes + num_prefills == num_reqs + assert num_decode_tokens + num_prefill_tokens == num_tokens prefill_metadata = None - if self._num_prefills > 0: - reqs_start = self._num_decodes # prefill_start + if num_prefills > 0: + reqs_start = num_decodes # prefill_start - context_lens_cpu = self.runner.input_batch.\ - num_computed_tokens_cpu_tensor[reqs_start:num_reqs] + context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] max_context_len_cpu = context_lens_cpu.max().item() num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() prefill_query_start_loc = query_start_loc[ reqs_start:] - query_start_loc[reqs_start] chunked_context_metadata = None - if self.chunked_prefill_enabled and self._num_prefills > 0 \ + if self.chunked_prefill_enabled and num_prefills > 0 \ and max_context_len_cpu > 0: # NOTE: it is recommend you read the `Chunked Prefill` section # in the comment at the top of the file before trying to @@ -537,14 +555,14 @@ def build(self, common_prefix_len: int, # of `to_list`. 
chunk_starts = \ torch.arange(num_chunks, dtype=torch.int32) \ - .unsqueeze(1).expand(-1, self._num_prefills) \ + .unsqueeze(1).expand(-1, num_prefills) \ * max_context_chunk chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk) chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) cu_seq_lens_cpu = torch.zeros(num_chunks, - self._num_prefills + 1, + num_prefills + 1, dtype=torch.int32, pin_memory=True) torch.cumsum(chunk_seq_lens, @@ -572,25 +590,47 @@ def build(self, common_prefix_len: int, ) decode_metadata = None - if self._num_decodes > 0: + if num_decodes > 0: decode_metadata = self._build_decode( - block_table_tensor=block_table_tensor[:self._num_decodes, ...], - seq_lens=seq_lens[:self._num_decodes], + block_table_tensor=block_table_tensor[:num_decodes, ...], + seq_lens=seq_lens[:num_decodes], + ubatch_id=ubatch_id ) return self.metadata_cls( - num_actual_tokens=num_actual_tokens, + num_actual_tokens=num_tokens, query_start_loc=query_start_loc, slot_mapping=slot_mapping, head_dim=self.runner.model_config.get_head_size(), # MLACommonMetadata Chunk prefill specific - num_decodes=self._num_decodes, - num_decode_tokens=self._num_decode_tokens, - num_prefills=self._num_prefills, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, prefill=prefill_metadata, decode=decode_metadata, ) + + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata) -> M: + """ + This method builds the metadata for full cudagraph capture. + Currently, only decode is supported for full cudagraphs with MLA. + """ + m = common_attn_metadata + assert m.num_reqs == m.num_actual_tokens, \ + "MLA only supports decode-only full CUDAGraph capture. " \ + "Make sure all cudagraph capture sizes <= max_num_seq." + + m.max_query_len = 1 # decode-only + + # Update state usually set in reorder_batch. 
+ # self._num_decodes = m.num_reqs + # self._num_decode_tokens = m.num_actual_tokens + # self._num_prefills = 0 + # self._num_prefill_tokens = 0 + return self.build(0, m) + def can_run_in_cudagraph( self, common_attn_metadata: CommonAttentionMetadata) -> bool: return common_attn_metadata.max_query_len == 1 diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index be26e0060db5..9e8c36413b25 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -63,11 +63,13 @@ def __init__(self, runner, kv_cache_spec: AttentionSpec, self.num_q_heads = self.runner.model_config.get_num_attention_heads( self.runner.parallel_config) - self.cg_buf_tile_scheduler_metadata = None - self.cg_buf_num_splits = None + self.cg_buf_tile_scheduler_metadata = [None, None] + self.cg_buf_num_splits = [None, None] def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens: torch.Tensor) -> FlashMLADecodeMetadata: + seq_lens: torch.Tensor, ubatch_id: Optional[int] = None) -> FlashMLADecodeMetadata: + ubatch_id = 0 if ubatch_id is None else ubatch_id + assert ubatch_id < 2 tile_scheduler_metadata, num_splits = \ get_mla_metadata( seq_lens, @@ -75,28 +77,31 @@ def _build_decode(self, block_table_tensor: torch.Tensor, 1, # MQA for the decode path ) + # logger.info(f"N : {n} bs: {self.runner.cudagraph_batch_sizes[-1]}") if self.runner.full_cuda_graph: + n = num_splits.size(0) # First time around (CUDAGraph capture), allocate the static buffer - if self.cg_buf_tile_scheduler_metadata is None: - self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata - self.cg_buf_num_splits = num_splits - else: - assert self.cg_buf_num_splits is not None + if self.cg_buf_num_splits[ubatch_id] is None: + self.cg_buf_num_splits[ubatch_id] = num_splits + self.cg_buf_tile_scheduler_metadata[ubatch_id] = tile_scheduler_metadata + elif n <= self.cg_buf_num_splits[ubatch_id].size(0): + assert self.cg_buf_tile_scheduler_metadata[ubatch_id] is not None # Metadata per-SM, fixed size (#SMs, TileMetadataSize) - assert (self.cg_buf_tile_scheduler_metadata.size() == + assert (self.cg_buf_tile_scheduler_metadata[ubatch_id].size() == tile_scheduler_metadata.size()) - self.cg_buf_tile_scheduler_metadata.\ + self.cg_buf_tile_scheduler_metadata[ubatch_id].\ copy_(tile_scheduler_metadata) - tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata + tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata[ubatch_id] # Num splits is per-batch, varying size (batch_size,) n = num_splits.size(0) + # logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}") # make sure static buffer is large enough - assert n <= self.cg_buf_num_splits.size(0) - num_splits_view = self.cg_buf_num_splits[:n] + assert n <= self.cg_buf_num_splits[ubatch_id].size(0) + num_splits_view = self.cg_buf_num_splits[ubatch_id][:n] num_splits_view.copy_(num_splits) - self.cg_buf_num_splits[n:].fill_(0) # fill the rest with 0s + self.cg_buf_num_splits[ubatch_id][n:].fill_(0) # fill the rest with 0s num_splits = num_splits_view return FlashMLADecodeMetadata( diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 8083f2002602..479b431b67ed 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -4,7 +4,7 @@ import functools from abc import abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar +from typing import TYPE_CHECKING, ClassVar, 
Generic, TypeVar, Optional import numpy as np import torch @@ -19,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import ( get_kv_connector_cache_layout) from vllm.logger import init_logger +from vllm.v1.worker.block_table import BlockTable logger = init_logger(__name__) @@ -31,8 +32,11 @@ class CommonAttentionMetadata: """ query_start_loc: torch.Tensor + query_start_loc_cpu: torch.Tensor """(batch_size + 1,), the start location of each request in query Tensor""" + seq_lens: torch.Tensor + seq_lens_cpu: torch.Tensor """(batch_size,), the length of each request including both computed tokens and newly scheduled tokens""" @@ -43,6 +47,92 @@ class CommonAttentionMetadata: max_query_len: int """Longest query in batch""" + block_table_tensor: torch.Tensor + slot_mapping: torch.Tensor + slot_mapping_cpu: torch.Tensor + + def __post_init__(self): + self.slot_mapping[:self.num_actual_tokens].copy_( + self.slot_mapping_cpu[:self.num_actual_tokens], + non_blocking=True) + # Fill unused with -1. Needed for reshape_and_cache in full cuda graph + # mode. + self.slot_mapping[self.num_actual_tokens:].fill_(-1) + + # CUDA Graph Buffers; 2 possible slots for dual batch overlap + _cg_query_start_loc: ClassVar[list[Optional[torch.Tensor]]] = [None, None] + _cg_seq_lens: ClassVar[list[Optional[torch.Tensor]]] = [None, None] + + def compute_request_slice(self, token_slice: slice) -> slice: + """ + use the query_start_loc_cpu to find the requests that the token_slice + spans. + """ + if self.max_query_len == 1: + # Pure decode + return token_slice + else: + # Find the first query_start_loc that's greater than the token_slice.start + first_request = (self.query_start_loc_cpu >= token_slice.start).int().argmax(dim=-1).item() + last_request = (self.query_start_loc_cpu < token_slice.stop).int().argmax(dim=-1).item() + return slice(first_request, last_request) + + # Slice the current CommonAttentionMetatdata for microbatching + def _slice(self, token_slice: slice, cg_buffer_idx: Optional[int]) -> 'CommonAttentionMetadata': + request_slice = self.compute_request_slice(token_slice) + cg_idx = cg_buffer_idx + + num_requests = request_slice.stop - request_slice.start + num_actual_tokens = token_slice.stop - token_slice.start + + query_start_loc = slice_query_start_locs( + self.query_start_loc, request_slice) + query_start_loc_cpu = slice_query_start_locs( + self.query_start_loc_cpu, request_slice) + + seq_lens = self.seq_lens[request_slice] + seq_lens_cpu = self.seq_lens_cpu[request_slice] + + # If we are ending partially in a request, adjust the seqlen to account + # for the "chopped off" tokens + # Use the un-modified query_start_loc_cpu to compute the number of tokens in the last request + if self.query_start_loc_cpu[request_slice.stop] > token_slice.stop: + seq_lens_cpu[num_requests - 1] -= token_slice.stop - self.query_start_loc_cpu[request_slice.stop] + # TODO(lucas): Try to avoid CPU to GPU transfer here? + seq_lens[num_requests - 1] = seq_lens_cpu[num_requests - 1] + + if cg_idx is not None: + # if cg_buffer_idx is not none we copy into local static buffers + # to make sure the common attention metadata cudagraph compilable + # NOTE(lucas): this assumes cudagraphs are captured in descending + # order of size. 
+ if self._cg_query_start_loc[cg_idx] is None: + self._cg_query_start_loc[cg_idx] = query_start_loc + self._cg_seq_lens[cg_idx] = seq_lens + + # Alias to appease the type checker + cg_seq_lens = self._cg_seq_lens[cg_idx] + cg_query_start_loc = self._cg_query_start_loc[cg_idx] + assert cg_seq_lens is not None and cg_query_start_loc is not None + cg_seq_lens[:num_requests].copy_( + seq_lens, non_blocking=True) + cg_query_start_loc[:num_requests+1].copy_( + query_start_loc, non_blocking=True) + seq_lens = cg_seq_lens[:num_requests] + query_start_loc = cg_query_start_loc[:num_requests+1] + + return CommonAttentionMetadata( + query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu, + num_reqs=num_requests, + num_actual_tokens=num_actual_tokens, + max_query_len=self.max_query_len, + block_table_tensor=self.block_table_tensor[request_slice], + slot_mapping=self.slot_mapping[token_slice], + slot_mapping_cpu=self.slot_mapping_cpu[token_slice], + ) M = TypeVar("M") @@ -53,7 +143,8 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): @abstractmethod def build(self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata) -> M: + common_attn_metadata: CommonAttentionMetadata, + ubatch_id: Optional[int] = None) -> M: """ Central method that builds attention metadata. Some builders (MLA) require reorder_batch to be called prior to build. @@ -98,6 +189,13 @@ def reorder_batch(self, input_batch: "InputBatch", return False +def slice_query_start_locs( + query_start_loc: torch.Tensor, + req_slice: slice, +) -> torch.Tensor: + return query_start_loc[req_slice.start: req_slice.stop + 1] -\ + query_start_loc[req_slice.start] + def validate_kv_sharing_target(current_layer_name, target_layer_name, static_forward_context): error_msg = (f"Specified KV sharing target layer for {current_layer_name} " diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 29d39de212f8..0862644b2828 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3,10 +3,11 @@ import copy import gc +import threading import time import weakref from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, TypeAlias, Union import numpy as np import torch @@ -28,7 +29,8 @@ from vllm.distributed.parallel_state import ( get_pp_group, get_tp_group, graph_capture, prepare_communication_buffer_for_model) -from vllm.forward_context import (DPMetadata, get_forward_context, +from vllm.forward_context import (create_forward_context, get_forward_context, + override_forward_context, DPMetadata, set_forward_context) from vllm.logger import init_logger from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 @@ -46,6 +48,7 @@ GiB_bytes, LazyLoader, async_tensor_h2d, cdiv, check_use_alibi, get_dtype_size, is_pin_memory_available, round_up) +from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) @@ -68,6 +71,7 @@ from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin +from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts from .utils import (gather_mm_placeholders, 
initialize_kv_cache_for_kv_sharing, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) @@ -86,6 +90,21 @@ logger = init_logger(__name__) +AttnMetadataDict: TypeAlias = dict[str, FlashAttentionMetadata] +# list when ubatching is enabled +PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict], + AttnMetadataDict] + +UbatchSlice: TypeAlias = tuple[slice, slice] +UBatchSlices: TypeAlias = list[UbatchSlice] + + +import dataclasses +@dataclasses.dataclass +class CUDAGraphMetaData: + cudagraph: torch.cuda.CUDAGraph + using_ubatching: bool + outputs: Optional[Any] = None class GPUModelRunner(LoRAModelRunnerMixin): @@ -129,6 +148,7 @@ def __init__( self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs + self.cudagraphs = {} # Model-related. self.num_query_heads = model_config.get_num_attention_heads( parallel_config) @@ -219,14 +239,18 @@ def __init__( == CompilationLevel.PIECEWISE and self.vllm_config.compilation_config.use_cudagraph and not self.model_config.enforce_eager) + self.use_cuda_graph = False + logger.info(f"self.use_cuda_graph {self.use_cuda_graph}") # TODO(woosuk): Provide an option to tune the max cudagraph batch size. # The convention is different. # self.cudagraph_batch_sizes sorts in ascending order. # The batch sizes in the config are in descending order. self.cudagraph_batch_sizes = list( reversed(self.compilation_config.cudagraph_capture_sizes)) - + logger.info(f"cudagraph capture sizes {self.cudagraph_batch_sizes}") self.full_cuda_graph = self.compilation_config.full_cuda_graph + # self.full_cuda_graph = True + logger.info(f"full_cuda_graph {self.full_cuda_graph}") # Cache the device properties. self._init_device_properties() @@ -556,6 +580,44 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: if batch_changed or batch_reordered: self.input_batch.refresh_sampling_metadata() + def _ubatch_split( + self, + max_num_scheduled_tokens: int, + scheduler_output: "SchedulerOutput") -> Optional[UBatchSlices]: + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + num_reqs = self.input_batch.num_reqs + + if self.parallel_config.enable_microbatching and \ + total_num_scheduled_tokens >= \ + self.parallel_config.microbatching_token_threshold \ + and max_num_scheduled_tokens == 1: + # For pure decode we can just create ubatchs by cutting the request + # in half + b0_reqs_end = num_reqs // 2 + b0_tokens_end = total_num_scheduled_tokens // 2 + assert b0_reqs_end < num_reqs and \ + b0_tokens_end < total_num_scheduled_tokens + return [ + (slice(0, b0_reqs_end), slice(0, b0_tokens_end)), + (slice(b0_reqs_end, num_reqs), + slice(b0_tokens_end, total_num_scheduled_tokens)), + ] + + # if self.parallel_config.enable_microbatching and \ + # self.parallel_config.always_microbatch_if_enabled: + # print(f"PREFIL RUN total_num_scheduled_tokens: {total_num_scheduled_tokens} max_num_scheduled_tokens {max_num_scheduled_tokens}") + # TODO we can do something more advanced here to try to balance, + # i.e. 
split to the left of `total_num_scheduled_tokens // 2` if it + # is more balanced + # req_split_id = np.argmax( + # query_start_loc_np > (total_num_scheduled_tokens // 2)) + # return [(slice(0, req_split_id), + # slice(0, query_start_loc_np[req_split_id])), + # (slice(req_split_id, num_reqs), + # slice(query_start_loc_np[req_split_id], + # total_num_scheduled_tokens))] + return None + def _get_cumsum_and_arange( self, num_tokens: np.ndarray, @@ -579,8 +641,9 @@ def _get_cumsum_and_arange( def _prepare_inputs( self, scheduler_output: "SchedulerOutput", - ) -> tuple[dict[str, Any], bool, torch.Tensor, - Optional[SpecDecodeMetadata], np.ndarray]: + ) -> tuple[PerLayerAttnMetadata, bool, torch.Tensor, + Optional[SpecDecodeMetadata], np.ndarray, Optional[UBatchSlices], + int, Optional[torch.Tensor]]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, @@ -667,6 +730,56 @@ def _prepare_inputs( self.query_start_loc_np[0] = 0 self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens + ubatch_slices: Optional[UBatchSlices] = self._ubatch_split( + max_num_scheduled_tokens, + scheduler_output) + should_ubatch = self.should_ubatch(True if ubatch_slices else False) + # Don't attempt to microbatch unless every other DP worker is also microbatching + if not should_ubatch: + ubatch_slices = None + + num_pad_tokens = 0 + num_tokens_after_padding = None + ubatch_bailout = False + if ubatch_slices: + # logger.info(f"ATTEMPTING TO PAD UBATCH {should_ubatch}") + assert should_ubatch + num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(ubatch_slices) + logger.info(f"num_tokens {scheduler_output.total_num_scheduled_tokens} num_pad_tokens {num_pad_tokens} num_toknes_after {num_tokens_after_padding}") + # logger.info("UBATCH PADDING DONE") + if num_pad_tokens > 0: + if num_pad_tokens < scheduler_output.total_num_scheduled_tokens: + self.pad_out_ubatch_first_stage(ubatch_slices, num_pad_tokens) + else: + # We bail out of ubatching here. This accounts for the case where + # the padding would result in an "empty" second ubatch. + # TODO: just make the second ubatch a dummy ubatch + # logger.info("FALLING BACK AND DISABLING UBATCHING") + ubatch_bailout = True + + # Note that if we are attempting to ubatch by this point then we know that no + # DP ranks are doing dummy runs + if ubatch_slices: + should_ubatch = self.should_ubatch(False if ubatch_bailout else True) + if not should_ubatch: + # logger.info("SUCCESSFULLY BAILED OUT") + num_pad_tokens = 0 + num_tokens_after_padding = None + ubatch_slices = None + + + # This AR is only necessary in the case described above where + # the second ubatch ends up being empty. 
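
The two-stage padding performed here is easier to see in isolation. Below is a stand-alone sketch (plain Python, made-up token counts; split_with_padding and extend_second_ubatch are illustrative stand-ins for pad_out_ubatch_first_stage and pad_out_ubatch_second_stage) of how the split point is shifted so that only the second micro-batch absorbs the padding.

def split_with_padding(total_num_tokens: int, num_pad_tokens: int):
    # Stage one: pick the split point as if the padding were already appended,
    # so both micro-batches end up with the same padded size.
    per_ubatch = (total_num_tokens + num_pad_tokens) // 2
    return slice(0, per_ubatch), slice(per_ubatch, total_num_tokens)

def extend_second_ubatch(second: slice, padded_total: int) -> slice:
    # Stage two: after attention metadata is built, stretch the second slice
    # out to the padded total so it covers the pad tokens.
    return slice(second.start, padded_total)

first, second = split_with_padding(total_num_tokens=10, num_pad_tokens=2)
assert (first.stop - first.start) == 6       # 6 real tokens
assert (second.stop - second.start) == 4     # 4 real tokens, 2 pad slots later
second = extend_second_ubatch(second, padded_total=12)
assert (second.stop - second.start) == 6     # both micro-batches now see 6 slots
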
NOte if you delete this go delete + # the second should_ubatch call in _dummy_run + # should_ubatch = self.should_ubatch(True if ubatch_slices else False) + # if not should_ubatch: + # num_pad_tokens = 0 + # num_tokens_after_padding = None + # ubatch_slices = None + + + + self.seq_lens_np[:num_reqs] = ( self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens) @@ -698,22 +811,29 @@ def _prepare_inputs( self.query_start_loc_cpu[num_reqs].item()) query_start_loc = self.query_start_loc[:num_reqs + 1] - seq_lens = self.seq_lens[:num_reqs] - - common_attn_metadata = CommonAttentionMetadata( - query_start_loc=query_start_loc, - seq_lens=seq_lens, - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - ) - attn_metadata: dict[str, Any] = {} + attn_metadata: PerLayerAttnMetadata = {} + if ubatch_slices is not None: + attn_metadata = [dict() for _ in range(len(ubatch_slices))] + # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=self.query_start_loc[:num_reqs + 1], + query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], + seq_lens=self.seq_lens[:num_reqs], + seq_lens_cpu=self.seq_lens_cpu[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + block_table_tensor=self.input_batch.block_table[kv_cache_group_id].get_device_tensor()[:num_reqs], + slot_mapping=self.input_batch.block_table[kv_cache_group_id].slot_mapping[:num_reqs], + slot_mapping_cpu=self.input_batch.block_table[kv_cache_group_id].slot_mapping_cpu[:num_reqs], + ) + # Prepare for cascade attention if enabled & beneficial. common_prefix_len = 0 builder = self.attn_metadata_builders[kv_cache_group_id] @@ -726,13 +846,38 @@ def _prepare_inputs( builder, ) - attn_metadata_i = (builder.build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata, - )) - - for layer_name in kv_cache_group_spec.layer_names: - attn_metadata[layer_name] = attn_metadata_i + # Fill unused with -1. Needed for reshape_and_cache in full cuda + # graph mode. + if self.vllm_config.compilation_config.full_cuda_graph: + self.input_batch.block_table[kv_cache_group_id]\ + .slot_mapping.fill_(-1) + + if ubatch_slices is not None: + for ubid, (req_slice, token_slice) in enumerate(ubatch_slices): + # Run a dummy batch if its a empty ubatch + if token_slice.stop <= token_slice.start: + attn_metadata_i = None + else: + attn_metadata_i = ( + self.attn_metadata_builders[kv_cache_group_id]. + build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata._slice(token_slice, cg_buffer_idx=ubid), + ubatch_id=ubid + )) + for layer_name in kv_cache_group_spec.layer_names: + assert type(attn_metadata) is list + # assert attn_metadata_i is not None + # What if it's None? Do we still add it to the list? 
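
The per-ubatch metadata above relies on rebasing the query_start_loc prefix sums so that each micro-batch counts tokens from zero again, in the spirit of the slice_query_start_locs helper added to the attention utils. A CPU-only sketch with made-up values:

import torch

def rebase_query_start_loc(query_start_loc: torch.Tensor,
                           req_slice: slice) -> torch.Tensor:
    # Keep the cumulative lengths for the selected requests and shift them so
    # the first selected request starts at token offset 0.
    sliced = query_start_loc[req_slice.start:req_slice.stop + 1]
    return sliced - sliced[0]

# Four decode requests with one token each.
qsl = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int32)
assert rebase_query_start_loc(qsl, slice(2, 4)).tolist() == [0, 1, 2]
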
+ attn_metadata[ubid][layer_name] = attn_metadata_i + else: + attn_metadata_i = ( + self.attn_metadata_builders[kv_cache_group_id].build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata)) + for layer_name in kv_cache_group_spec.layer_names: + assert type(attn_metadata) is dict + attn_metadata[layer_name] = attn_metadata_i attention_cuda_graphs = all( b.can_run_in_cudagraph(common_attn_metadata) @@ -766,8 +911,10 @@ def _prepare_inputs( if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) - return (attn_metadata, attention_cuda_graphs, logits_indices, - spec_decode_metadata, num_scheduled_tokens) + return (attn_metadata, attention_cuda_graphs, logits_indices, + spec_decode_metadata, num_scheduled_tokens, + ubatch_slices, num_pad_tokens, + num_tokens_after_padding) def _compute_cascade_attn_prefix_len( self, @@ -1154,10 +1301,12 @@ def apply_grammar_bitmask( ) def sync_and_slice_intermediate_tensors( - self, num_tokens: int, intermediate_tensors: IntermediateTensors, + self, tokens_slice: slice, + intermediate_tensors: IntermediateTensors, sync_self: bool) -> IntermediateTensors: assert self.intermediate_tensors is not None + num_tokens = tokens_slice.stop - tokens_slice.start tp = self.vllm_config.parallel_config.tensor_parallel_size enabled_sp = self.compilation_config.pass_config. \ @@ -1169,21 +1318,24 @@ def sync_and_slice_intermediate_tensors( is_residual_scattered = tp > 1 and enabled_sp \ and num_tokens % tp == 0 + def copy_slice(is_scattered: bool) -> slice: + if is_scattered: + return slice(tokens_slice.start // tp, tokens_slice.stop // tp) + else: + return tokens_slice + # When sequence parallelism is enabled, the "residual" tensor is sharded # across tensor parallel ranks, so each rank only needs its own slice. if sync_self: assert intermediate_tensors is not None for k, v in intermediate_tensors.items(): - is_scattered = "residual" and is_residual_scattered - copy_len = num_tokens // tp if is_scattered else \ - num_tokens - self.intermediate_tensors[k][:copy_len].copy_( - v[:copy_len], non_blocking=True) + _copy_slice = copy_slice(is_residual_scattered) + self.intermediate_tensors[k][_copy_slice].copy_( + v[_copy_slice], non_blocking=True) return IntermediateTensors({ k: - v[:num_tokens // tp] - if k == "residual" and is_residual_scattered else v[:num_tokens] + v[copy_slice(k == "residual" and is_residual_scattered)] for k, v in self.intermediate_tensors.items() }) @@ -1217,7 +1369,7 @@ def get_dp_padding(self, # TODO(tms) : There are many cases where padding is enabled for # prefills, causing unnecessary and excessive padding of activations. - if dp_size == 1 or self.vllm_config.model_config.enforce_eager: + if dp_size == 1: # Early exit. return 0, None @@ -1230,6 +1382,459 @@ def get_dp_padding(self, dtype=torch.int32) return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding + + def get_padding(self, + num_tokens_unpadded: int) -> tuple[int, Optional[torch.Tensor]]: + + num_tokens_padded = num_tokens_unpadded + + # logger.info(f"num tokens unpadded: {num_tokens_unpadded} cudagraphs: {self.cudagraph_batch_sizes}") + if (self.use_cuda_graph + and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]): + # Use piecewise CUDA graphs. + # Add padding to the batch size. + num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens_unpadded) + else: + # Eager mode. 
+ # Pad tokens to multiple of tensor_parallel_size when + # enabled collective fusion for SP + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + if self.vllm_config.compilation_config.pass_config. \ + enable_sequence_parallelism and tp_size > 1: + from vllm.utils import round_up + num_tokens_padded = round_up(num_tokens_unpadded, tp_size) + + num_pad_tokens = num_tokens_padded - num_tokens_unpadded + num_dp_pad_tokens, num_tokens_after_padding = self.get_dp_padding(num_tokens_padded) + + return num_dp_pad_tokens + num_pad_tokens, num_tokens_after_padding + + def get_dp_padding_ubatch(self, + ubatch_slices: UBatchSlices, + include_cudagraphs: bool = True) -> tuple[int, Optional[torch.Tensor]]: + dp_size = self.vllm_config.parallel_config.data_parallel_size + + if dp_size == 1: + # Early exit. + return 0, None + + first_ubatch_slice = ubatch_slices[0] + second_ubatch_slice = ubatch_slices[1] + + first_ubatch_num_tokens = first_ubatch_slice[1].stop - first_ubatch_slice[1].start + second_ubatch_num_tokens = second_ubatch_slice[1].stop - second_ubatch_slice[1].start + # We don't support prefills yet so the two ubatches should only differ + # by at most one token + assert abs(first_ubatch_num_tokens - second_ubatch_num_tokens) <= 1 + + from vllm.utils import round_up + + num_tokens_unpadded = first_ubatch_num_tokens + second_ubatch_num_tokens + num_tokens_padded = round_up(num_tokens_unpadded, 2) + if (include_cudagraphs and self.use_cuda_graph + and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]): + # Add padding to the batch size. + num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens_unpadded) + else: + # Eager mode. + # Pad tokens to multiple of tensor_parallel_size when + # enabled collective fusion for SP + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + if self.vllm_config.compilation_config.pass_config. \ + enable_sequence_parallelism and tp_size > 1: + num_tokens_padded = round_up(num_tokens_unpadded, tp_size) + + num_tokens_per_ubatch = num_tokens_padded // 2 + + # Note that we compute the number of padded tokens per ubatch + num_pad_tokens, num_tokens_after_padding = self.get_dp_padding(num_tokens_per_ubatch) + + num_pad_tokens = ((num_pad_tokens + num_tokens_per_ubatch) * 2) - \ + num_tokens_unpadded + return num_pad_tokens, num_tokens_after_padding + + # This doesn't actually pad the ubatch slices. It just shifts the + # split point to the correct value so that padding can be applied + # to the second ubatch later. 
+    # Should be called after ubatch slicing but before attention
+    # metadata creation.
+    def pad_out_ubatch_first_stage(self, ubatch_slices: UBatchSlices,
+                                   num_pad_tokens: int):
+        original_num_tokens = ubatch_slices[1][1].stop
+        assert num_pad_tokens < original_num_tokens
+        total_num_tokens_per_ubatch = (original_num_tokens + num_pad_tokens) // 2
+        padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
+        padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch, original_num_tokens)
+
+        ubatch_slices[0] = (padded_first_ubatch_slice, padded_first_ubatch_slice)
+        ubatch_slices[1] = (padded_second_ubatch_slice, padded_second_ubatch_slice)
+
+        # if (num_pad_tokens_first_ubatch > 0):
+        #     print(f"FIRST UBATCH PADDING {num_pad_tokens_first_ubatch} TOTAL: {max_tokens_across_dp_cpu} ORIGINAL{first_ubatch_num_tokens}")
+        # if (num_pad_tokens_second_ubatch > 0):
+        #     print(f"SECOND UBATCH PADDING {num_pad_tokens_second_ubatch} TOTAL: {max_tokens_across_dp_cpu} ORIGINAL{second_ubatch_num_tokens}")
+        # print(f"num padded tokens: {num_pad_tokens} num tokens tensor: {num_tokens_after_padding} first num_tokens: {first_ubatch_num_tokens} second num tokens {second_ubatch_num_tokens}")
+
+    # This is where the second ubatch is adjusted to account for the padding.
+    # Should be called after attention metadata creation. This just extends
+    # the second ubatch slice out to the total number of tokens
+    # (num_tokens + padding).
+    def pad_out_ubatch_second_stage(self, ubatch_slices: UBatchSlices, num_total_tokens: int):
+        # TODO: Add asserts to make sure stage one ran.
+        padded_second_ubatch_slice = slice(ubatch_slices[1][1].start, num_total_tokens)
+        ubatch_slices[1] = (padded_second_ubatch_slice, padded_second_ubatch_slice)
+
+    # Returns num_padded_tokens, which should be added to the current number of
+    # tokens. It is the sum of the number of padded tokens from DP padding and
+    # the number of padded tokens from cudagraph padding.
+    # The second tensor object is None when DP is disabled. When DP is enabled,
+    # it contains the number of tokens on each DP rank.
+    def compute_padding(self) -> tuple[int, Optional[torch.Tensor]]:
+        return (0, torch.Tensor())
+
+    def should_ubatch(self, should_ubatch: bool) -> bool:
+        dp_size = self.vllm_config.parallel_config.data_parallel_size
+        dp_rank = self.vllm_config.parallel_config.data_parallel_rank
+        return DPMetadata.should_ubatch_across_dp(should_ubatch, dp_size, dp_rank)
+
+    def _get_dummy_model_inputs(self, num_tokens: int) -> tuple:
+        # Dummy batch.
(hopefully we are the last one so we can just + # update this to a one token batch and return) + + if self.is_multimodal_model: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_tokens] + else: + input_ids = self.input_ids[:num_tokens] + inputs_embeds = None + + if self.uses_mrope: + positions = self.mrope_positions[:, :num_tokens] + else: + positions = self.positions[:num_tokens] + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + if self.intermediate_tensors is None: + self.intermediate_tensors = ( + self.model.make_empty_intermediate_tensors( + batch_size=self.max_num_tokens, + dtype=self.model_config.dtype, + device=self.device)) + + intermediate_tensors = self.sync_and_slice_intermediate_tensors( + slice(0, num_tokens), None, False) + + + return input_ids, positions, inputs_embeds, intermediate_tensors + + def _get_model_inputs(self, tokens_slice: slice, + scheduler_output: "SchedulerOutput"): + num_tokens = tokens_slice.stop - tokens_slice.start + if num_tokens == 0: + # Dummy batch. (hopefully we are the last one so we can just + # update this to a one token batch and return) + tokens_slice = slice(tokens_slice.start, tokens_slice.start + 1) + num_tokens = 1 + + # if (self.use_cuda_graph + # and num_tokens <= self.cudagraph_batch_sizes[-1]): + # # Use piecewise CUDA graphs. + # # Add padding to the batch size. + # tokens_slice = \ + # slice(tokens_slice.start, tokens_slice.start+ + # self.vllm_config.pad_for_cudagraph(num_tokens)) + # else: + # # Eager mode. + # # Pad tokens to multiple of tensor_parallel_size when + # # enabled collective fusion for SP + # tp_size = self.vllm_config.parallel_config.tensor_parallel_size + # if self.vllm_config.compilation_config.pass_config. \ + # enable_sequence_parallelism and tp_size > 1: + # from vllm.utils import round_up + # tokens_slice = slice( + # tokens_slice.start, + # tokens_slice.start + round_up(num_tokens, tp_size)) + + # update num tokens for padding + # num_tokens = tokens_slice.stop - tokens_slice.start + + # _prepare_inputs may reorder the batch, so we must gather multi + # modal outputs after that to ensure the correct order + if self.is_multimodal_model: + # Run the multimodal encoder if any. + self._execute_mm_encoder(scheduler_output) + mm_embeds = self._gather_mm_embeddings(scheduler_output) + else: + mm_embeds = [] + + if self.is_multimodal_model and get_pp_group().is_first_rank: + # NOTE(woosuk): To unify token ids and soft tokens (vision + # embeddings), we always use embeddings (rather than token ids) + # as input to the multimodal model, even when the input is text. + input_ids = self.input_ids[tokens_slice] + if mm_embeds: + inputs_embeds = self.model.get_input_embeddings( + input_ids, mm_embeds) + else: + inputs_embeds = self.model.get_input_embeddings(input_ids) + # TODO(woosuk): Avoid the copy. Optimize. + self.inputs_embeds[tokens_slice].copy_(inputs_embeds) + inputs_embeds = self.inputs_embeds[tokens_slice] + input_ids = None + else: + # For text-only models, we use token ids as input. + # While it is possible to use embeddings as input just like the + # multimodal models, it is not desirable for performance since + # then the embedding layer is not included in the CUDA graph. 
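
Indexing with tokens_slice here (instead of allocating fresh tensors every step) keeps the model inputs as views into persistent buffers, so their device addresses stay stable across steps, which CUDA graph replay depends on. A minimal CPU sketch of that property, with a hypothetical buffer size:

import torch

# Stand-in for a persistent buffer such as self.input_ids.
input_ids_buffer = torch.zeros(16, dtype=torch.int64)

def bind_tokens(token_ids: torch.Tensor, tokens_slice: slice) -> torch.Tensor:
    # Copy this step's tokens into the persistent buffer and hand back a view;
    # the view aliases the same storage, so its base address never changes.
    input_ids_buffer[tokens_slice].copy_(token_ids)
    return input_ids_buffer[tokens_slice]

view_a = bind_tokens(torch.arange(4), slice(0, 4))
view_b = bind_tokens(torch.arange(4), slice(0, 4))
assert view_a.data_ptr() == view_b.data_ptr() == input_ids_buffer.data_ptr()
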
+ input_ids = self.input_ids[tokens_slice] + inputs_embeds = None + if self.uses_mrope: + positions = self.mrope_positions[:, tokens_slice] + else: + positions = self.positions[tokens_slice] + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + intermediate_tensors = self.sync_and_slice_intermediate_tensors( + tokens_slice, intermediate_tensors, True) + return input_ids, positions, inputs_embeds, intermediate_tensors + + def _run_model(self, + attn_metadata: Optional[PerLayerAttnMetadata], + num_scheduled_tokens: Optional[int], + ubatch_slices: Optional[UBatchSlices] = None, + scheduler_output: Optional["SchedulerOutput"] = None, + is_dummy_run: bool = False, + num_tokens_across_dp: Optional[torch.Tensor] = None, + skip_cuda_graphs: bool = False, + build_cuda_graph: bool = False): + + @dataclasses.dataclass + class UbatchMetadata: + context: UBatchContext + input_ids: torch.Tensor + positions: torch.Tensor + inputs_embeds: Optional[torch.Tensor] + intermediate_tensors: Optional[IntermediateTensors] + + + num_dummy_tokens = num_scheduled_tokens if is_dummy_run else 1 + + def _make_ubatch_contexts(ubatch_slices, + attn_metadata, + compute_stream, + num_tokens_across_dp, + skip_cuda_graphs) -> list[UBatchContext]: + ubatch_ctxs = make_ubatch_contexts(len(ubatch_slices), + compute_stream=compute_stream, + device=self.device) + + for i, (_, tokens_slice) in enumerate(ubatch_slices): + num_tokens = (tokens_slice.stop - tokens_slice.start) + # TODO (Sage) Instead of using this setter we should be able + # to just create the forward context in advance and pass it + # to the UBatchContext's __init__ method + ubatch_ctxs[i].forward_context = create_forward_context( + attn_metadata[i] + if attn_metadata is not None else None, + self.vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, + skip_cuda_graphs=skip_cuda_graphs) + return ubatch_ctxs + + def model_inputs(tokens_slice: slice, use_dummy_input: bool) -> tuple: + if use_dummy_input: + # print("MAKING DUMMY BATCH") + # assert num_dummy_tokens == 1 + return self._get_dummy_model_inputs(num_dummy_tokens) + else: + assert scheduler_output is not None + return self._get_model_inputs(tokens_slice, scheduler_output) + + def _make_ubatch_metadata(ubatch_slices, + attn_metadata, + compute_stream, + is_dummy_run, + num_tokens_across_dp, + skip_cuda_graphs) -> list[UbatchMetadata]: + ubatch_ctxs = _make_ubatch_contexts( + ubatch_slices=ubatch_slices, + attn_metadata=attn_metadata, + compute_stream=compute_stream, + num_tokens_across_dp=num_tokens_across_dp, + skip_cuda_graphs=skip_cuda_graphs + ) + # First get some inputs + ubatch_metadata: list[UbatchMetadata] = [] + for i, (_, tokens_slice) in enumerate(ubatch_slices): + input_ids, positions, inputs_embeds, intermediate_tensors = \ + model_inputs(tokens_slice, is_dummy_run) + ubatch_metadata.append(UbatchMetadata( + context=ubatch_ctxs[i], + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors + )) + + return ubatch_metadata + + def _run(context, + input_ids, + positions, + inputs_embeds, + intermediate_tensors, + start_signal=None): + with context: + if start_signal is not None: + start_signal.wait() + model_output = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + if isinstance(context, UBatchContext): + # Clone before we leave the ubatch context + model_output = model_output.clone() + + return 
model_output + + @torch.inference_mode() + def _ubatch_thread(results, ubatch_metadata, start_signal): + # print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True) + context = ubatch_metadata.context + with torch.cuda.stream(context.compute_stream): + _ = torch.cuda.current_blas_handle() + with torch.cuda.stream(context.comm_stream): + _ = torch.cuda.current_blas_handle() + model_output = _run(context=ubatch_metadata.context, + input_ids=ubatch_metadata.input_ids, + positions=ubatch_metadata.positions, + inputs_embeds=ubatch_metadata.inputs_embeds, + intermediate_tensors=ubatch_metadata.intermediate_tensors, + start_signal=start_signal) + + results.append((ubatch_metadata.context.id, model_output)) + # print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True) + + def _run_ubatches(ubatch_metadata, num_tokens, should_capture=False) -> torch.Tensor: + results: list[tuple[int, torch.Tensor]] = [] + + # Ubatches will manually manage the forward context, so we override + # it to None here so we can have it restored correctly later + with override_forward_context(None): + ubatch_threads = [] + start_signals = [] + for metadata in ubatch_metadata: + start_signal = threading.Event() + thread = threading.Thread(target=_ubatch_thread, + args=( + results, + metadata, + start_signal, + )) + ubatch_threads.append(thread) + thread.start() + start_signals.append(start_signal) + + # DO capture + cudagraph_metadata = \ + CUDAGraphMetaData( + cudagraph=torch.cuda.CUDAGraph(), + using_ubatching=True + ) + with torch.cuda.graph(cudagraph_metadata.cudagraph, + stream=compute_stream): + # logger.info("STARTING WAKEUP LOOP") + for start_signal in start_signals: + start_signal.set() + # logger.info("FINISHED WAKEUP LOOP") + ubatch_metadata[0].context.cpu_wait_event.set() + for thread in ubatch_threads: + thread.join() + sorted_results = [value for position, value in sorted(results)] + result = torch.cat(sorted_results, dim=0) + cudagraph_metadata.outputs = result + logger.info(f"Capturing for {num_tokens} tokens") + self.cudagraphs[num_tokens] = cudagraph_metadata + return cudagraph_metadata.outputs + + # run micro-batched + if ubatch_slices is not None: + assert len(ubatch_slices) == 2, "Only two ubatches has been tested" + # num_tokens = ubatch_slices[1][1].stop + print(f"RUNNING UBATCH {ubatch_slices} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}") + # assert not is_dummy_run + compute_stream = torch.cuda.Stream(device=self.device) + ubatch_metadata = _make_ubatch_metadata( + ubatch_slices=ubatch_slices, + attn_metadata=attn_metadata, + compute_stream=compute_stream, + is_dummy_run=is_dummy_run, + num_tokens_across_dp=num_tokens_across_dp, + skip_cuda_graphs=skip_cuda_graphs + ) + if num_scheduled_tokens not in self.cudagraphs \ + and not skip_cuda_graphs and build_cuda_graph: + return _run_ubatches(ubatch_metadata, num_scheduled_tokens, should_capture=True) + elif num_scheduled_tokens in self.cudagraphs and not skip_cuda_graphs: + cudagraph_metadata = self.cudagraphs[num_scheduled_tokens] + logger.info(f"UBATCH REPLAY {num_scheduled_tokens}") + cudagraph_metadata.cudagraph.replay() + return cudagraph_metadata.outputs + else: + return _run_ubatches(ubatch_metadata, num_scheduled_tokens) + # run single batch + else: + input_ids, positions, inputs_embeds, intermediate_tensors = \ + model_inputs(slice(0, num_scheduled_tokens), is_dummy_run) + # if num_scheduled_tokens not in self.cudagraphs \ + # and not skip_cuda_graphs and build_cuda_graph: + # assert False + # 
logger.info(f"GRAPH BUILD{num_scheduled_tokens}") + # self.cudagraphs[num_scheduled_tokens] = \ + # CUDAGraphMetaData( + # cudagraph=torch.cuda.CUDAGraph(), + # using_ubatching=False + # ) + # with torch.cuda.graph(self.cudagraphs[num_scheduled_tokens].cudagraph): + # model_output = _run( + # context = set_forward_context(attn_metadata, + # vllm_config=self.vllm_config, + # num_tokens=num_scheduled_tokens or 1, + # num_tokens_across_dp=num_tokens_across_dp, + # skip_cuda_graphs=skip_cuda_graphs), + # input_ids=input_ids, + # positions=positions, + # inputs_embeds=inputs_embeds, + # intermediate_tensors=intermediate_tensors + # ) + # self.cudagraphs[num_scheduled_tokens].outputs = model_output + # return self.cudagraphs[num_scheduled_tokens].outputs + # elif num_scheduled_tokens in self.cudagraphs and not skip_cuda_graphs: + # assert False + # # logger.info(f"GRAPH REPLAY {num_scheduled_tokens}") + # self.cudagraphs[num_scheduled_tokens].cudagraph.replay() + # return self.cudagraphs[num_scheduled_tokens].outputs + # else: + # logger.info(f"NORMAL RUN {num_scheduled_tokens}") + return _run( + context = set_forward_context(attn_metadata, + vllm_config=self.vllm_config, + num_tokens=num_scheduled_tokens or 1, + num_tokens_across_dp=num_tokens_across_dp, + skip_cuda_graphs=skip_cuda_graphs), + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors + ) + def _pool( self, hidden_states: torch.Tensor, @@ -1290,99 +1895,42 @@ def execute_model( return self.kv_connector_no_forward(scheduler_output) + # num_scheduled_tokens_old = scheduler_output.total_num_scheduled_tokens + # num_pad_tokens, num_tokens_after_padding = self.get_dp_padding(num_scheduled_tokens_old) # Prepare the decoder inputs. - (attn_metadata, attention_cuda_graphs, logits_indices, - spec_decode_metadata, - num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output)) + attn_metadata, attention_cuda_graphs, logits_indices, spec_decode_metadata, num_scheduled_tokens_np, ubatch_slices, num_pad_tokens, num_tokens_after_padding = ( + self._prepare_inputs(scheduler_output)) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - if (self.use_cuda_graph - and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): - # Use piecewise CUDA graphs. - # Add padding to the batch size. - num_input_tokens = self.vllm_config.pad_for_cudagraph( - num_scheduled_tokens) - else: - # Eager mode. - # Pad tokens to multiple of tensor_parallel_size when - # enabled collective fusion for SP - tp_size = self.vllm_config.parallel_config.tensor_parallel_size - if self.compilation_config.pass_config. \ - enable_sequence_parallelism and tp_size > 1: - num_input_tokens = round_up(num_scheduled_tokens, tp_size) - else: - num_input_tokens = num_scheduled_tokens - - # Padding for DP - num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens) - num_input_tokens += num_pad - - # _prepare_inputs may reorder the batch, so we must gather multi - # modal outputs after that to ensure the correct order - if self.is_multimodal_model: - # Run the multimodal encoder if any. - self._execute_mm_encoder(scheduler_output) - mm_embeds = self._gather_mm_embeddings(scheduler_output) - else: - mm_embeds = [] - - if self.is_multimodal_model and get_pp_group().is_first_rank: - # NOTE(woosuk): To unify token ids and soft tokens (vision - # embeddings), we always use embeddings (rather than token ids) - # as input to the multimodal model, even when the input is text. 
- input_ids = self.input_ids[:num_scheduled_tokens] - if mm_embeds: - inputs_embeds = self.model.get_input_embeddings( - input_ids, mm_embeds) - else: - inputs_embeds = self.model.get_input_embeddings(input_ids) - # TODO(woosuk): Avoid the copy. Optimize. - self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds) - inputs_embeds = self.inputs_embeds[:num_input_tokens] - input_ids = None - else: - # For text-only models, we use token ids as input. - # While it is possible to use embeddings as input just like the - # multimodal models, it is not desirable for performance since - # then the embedding layer is not included in the CUDA graph. - input_ids = self.input_ids[:num_input_tokens] - inputs_embeds = None - if self.uses_mrope: - positions = self.mrope_positions[:, :num_input_tokens] - else: - positions = self.positions[:num_input_tokens] - - if get_pp_group().is_first_rank: - intermediate_tensors = None - else: - intermediate_tensors = self.sync_and_slice_intermediate_tensors( - num_input_tokens, intermediate_tensors, True) + num_input_tokens = num_scheduled_tokens + if ubatch_slices and num_pad_tokens > 0: + num_input_tokens += num_pad_tokens + self.pad_out_ubatch_second_stage(ubatch_slices, num_input_tokens) + elif ubatch_slices is None: + # logger.info("ATTEMPTING TO PAD NORMAL BATCH") + num_pad, num_tokens_after_padding = self.get_padding(num_input_tokens) + # logger.info("NORMAL BATCH DONE") + num_input_tokens += num_pad # Some attention backends only support CUDA Graphs in pure decode. # If attention doesn't support CUDA Graphs for this batch, but we # compiled with full CUDA graphs, we have to skip them entirely. skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs + # logger.info("RUNNING MODEL") # Run the model. # Use persistent buffers for CUDA graphs. - with set_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens, - num_tokens_across_dp=num_tokens_across_dp, - skip_cuda_graphs=skip_cuda_graphs, - ): - self.maybe_setup_kv_connector(scheduler_output) - - model_output = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - ) - - self.maybe_wait_for_kv_save() - finished_sending, finished_recving = ( - self.get_finished_kv_transfers(scheduler_output)) + self.maybe_setup_kv_connector(scheduler_output) + model_output = self._run_model( + attn_metadata=attn_metadata, + num_scheduled_tokens=num_input_tokens, + ubatch_slices=ubatch_slices, + scheduler_output=scheduler_output, + num_tokens_across_dp=num_tokens_after_padding, + skip_cuda_graphs=skip_cuda_graphs, + ) + self.maybe_wait_for_kv_save() + finished_sending, finished_recving = ( + self.get_finished_kv_transfers(scheduler_output)) if self.use_aux_hidden_state_outputs: hidden_states, aux_hidden_states = model_output @@ -1613,6 +2161,8 @@ def propose_draft_token_ids( if spec_decode_metadata is None: # input_ids can be None for multimodal models. target_token_ids = self.input_ids[:num_scheduled_tokens] + + #TODO(sage) make sure this works with mrope # TODO(woosuk): Support M-RoPE. target_positions = self.positions[:num_scheduled_tokens] if self.use_aux_hidden_state_outputs: @@ -1642,6 +2192,7 @@ def propose_draft_token_ids( num_tokens, ) target_token_ids = self.input_ids[token_indices] + # TODO(sage) make sure this works with mrope # TODO(woosuk): Support M-RoPE. 
target_positions = self.positions[token_indices] if self.use_aux_hidden_state_outputs: @@ -1666,6 +2217,7 @@ def propose_draft_token_ids( def kv_connector_no_forward( self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: + # KV send/recv even if no work to do. with set_forward_context(None, self.vllm_config): self.maybe_setup_kv_connector(scheduler_output) @@ -1951,15 +2503,42 @@ def rand_input_ids() -> torch.Tensor: def _dummy_run( self, num_tokens: int, + skip_attn: bool = True, + # Maybe return a cudagraph here capture_attn_cudagraph: bool = False, + + # For profiling runs we dont want microbatching but for + # dp dummy runs we do. + allow_microbatching: bool = False, + build_cuda_graph: bool = False, skip_eplb: bool = False, is_profile: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: + + # if allow_microbatching: + # logger.info("ATTEMPTING TO UBATCH THE DUMMY RUN") + + + # TODO(Sage) We need some more code to properly handle + # mixing normal and dummy runs. The DP padding needs to + # be properly setup. Since we only support microbatching + # in cuda graph capture it's fine to ignore the DP padding + # for now. + should_ubatch = num_tokens >= \ + self.parallel_config.microbatching_token_threshold and \ + allow_microbatching and capture_attn_cudagraph + # _dummy_run doesn't go through _prepare_inputs so + # we synchronize with other DP ranks here + # logger.info(f"NUM TOKENS {num_tokens} SHOULD UBATCH {should_ubatch}") + should_ubatch = self.should_ubatch(should_ubatch) # Padding for DP - num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) + # logger.info("PADDING DUMMY") + num_tokens_across_dp = None + num_pad = 0 + if not should_ubatch: + num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) num_tokens += num_pad - # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively # has num_tokens in total. @@ -1974,73 +2553,91 @@ def _dummy_run( num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) - attn_metadata: Optional[dict[str, Any]] = None + ubatch_slices = None + # We currently only microbatch if the number of tokens is + # over a certain threshold. + # logger.info("PADDING DUMMY DONE") + if should_ubatch: + # We only support decode-only cudagraphs + assert num_reqs == num_tokens + assert num_tokens % 2 == 0 + num_tokens_per_ubatch = num_tokens // 2 + num_tokens_across_dp = torch.tensor([num_tokens_per_ubatch] * 2, + device="cpu", + dtype=torch.int32) + ubatch_slices = [(slice(0, num_reqs // 2), + slice(0, num_tokens // 2)), + (slice(num_reqs // 2, num_reqs), + slice(num_tokens // 2, num_tokens))] + + + # attn_metadata: Optional[dict[str, Any]] = None + attn_metadata: Optional[PerLayerAttnMetadata]= None if capture_attn_cudagraph: attn_metadata = {} + if ubatch_slices is not None: + attn_metadata = [dict() for _ in range(len(ubatch_slices))] - query_start_loc = self.query_start_loc[:num_reqs + 1] # Make sure max_model_len is used at the graph capture time. 
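
For the decode-only capture path, the dummy run builds two equal micro-batch slices and still has to agree with the other DP ranks before committing to them. A plain-Python sketch of both pieces, with the cross-rank all-reduce that should_ubatch performs emulated by all() over per-rank flags (names are illustrative):

def ranks_agree_to_ubatch(per_rank_flags: list[bool]) -> bool:
    # Either every DP rank micro-batches or none of them do; otherwise the
    # collectives inside the model would be mismatched.
    return all(per_rank_flags)

def dummy_ubatch_slices(num_reqs: int, num_tokens: int):
    assert num_reqs == num_tokens and num_tokens % 2 == 0  # decode-only, even
    half_reqs, half_tokens = num_reqs // 2, num_tokens // 2
    return [(slice(0, half_reqs), slice(0, half_tokens)),
            (slice(half_reqs, num_reqs), slice(half_tokens, num_tokens))]

assert not ranks_agree_to_ubatch([True, False])  # one rank opts out, no ubatching
slices = dummy_ubatch_slices(num_reqs=8, num_tokens=8)
assert slices[0][1] == slice(0, 4) and slices[1][1] == slice(4, 8)
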
self.seq_lens_np[:num_reqs] = self.max_model_len self.seq_lens_np[num_reqs:] = 0 self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], non_blocking=True) - seq_lens = self.seq_lens[:num_reqs] - - common_attn_metadata = CommonAttentionMetadata( - query_start_loc=query_start_loc, - seq_lens=seq_lens, - num_reqs=num_reqs, - num_actual_tokens=num_tokens, - max_query_len=num_tokens, - ) - for kv_cache_group_id, kv_cache_group_spec in enumerate( + max_query_len = num_tokens + if ubatch_slices is not None: + max_query_len = 1 + for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=self.query_start_loc[:num_reqs + 1], + query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], + seq_lens=self.seq_lens[:num_reqs], + seq_lens_cpu=self.seq_lens_cpu[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=num_tokens, + max_query_len=max_query_len, + block_table_tensor=self.input_batch.block_table[kv_cache_group_id].get_device_tensor()[:num_reqs], + slot_mapping=self.input_batch.block_table[kv_cache_group_id].slot_mapping[:num_reqs], + slot_mapping_cpu=self.input_batch.block_table[kv_cache_group_id].slot_mapping_cpu[:num_reqs], + ) + + if ubatch_slices is not None: + for ubid, (req_slice, token_slice) in enumerate(ubatch_slices): + # Run a dummy batch if its a empty ubatch + if token_slice.stop <= token_slice.start: + attn_metadata_i = None + else: + attn_metadata_i = ( + self.attn_metadata_builders[kv_cache_group_id]. + build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata._slice(token_slice, cg_buffer_idx=ubid), + ubatch_id=ubid + )) + for layer_name in kv_cache_group_spec.layer_names: + assert type(attn_metadata) is list + # assert attn_metadata_i is not None + # What if it's None? Do we still add it to the list? 
+ attn_metadata[ubid][layer_name] = attn_metadata_i + else: + attn_metadata_i = self.attn_metadata_builders[ + kv_cache_group_id].build_for_cudagraph_capture( + common_attn_metadata) + for layer_name in kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i - attn_metadata_i = self.attn_metadata_builders[ - kv_cache_group_id].build_for_cudagraph_capture( - common_attn_metadata) - for layer_name in kv_cache_group_spec.layer_names: - attn_metadata[layer_name] = attn_metadata_i with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): - model = self.model - if self.is_multimodal_model: - input_ids = None - inputs_embeds = self.inputs_embeds[:num_tokens] - else: - input_ids = self.input_ids[:num_tokens] - inputs_embeds = None - if self.uses_mrope: - positions = self.mrope_positions[:, :num_tokens] - else: - positions = self.positions[:num_tokens] - - if get_pp_group().is_first_rank: - intermediate_tensors = None - else: - if self.intermediate_tensors is None: - self.intermediate_tensors = ( - self.model.make_empty_intermediate_tensors( - batch_size=self.max_num_tokens, - dtype=self.model_config.dtype, - device=self.device)) - - intermediate_tensors = self.sync_and_slice_intermediate_tensors( - num_tokens, None, False) - - with self.maybe_randomize_inputs(input_ids), set_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp): - outputs = model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - ) + outputs = self._run_model( + attn_metadata, + num_tokens, + ubatch_slices=ubatch_slices, + is_dummy_run=True, + num_tokens_across_dp=num_tokens_across_dp, + build_cuda_graph=build_cuda_graph + ) if self.use_aux_hidden_state_outputs: hidden_states, _ = outputs else: @@ -2255,6 +2852,7 @@ def profile_run(self) -> None: # Add `is_profile` here to pre-allocate communication buffers hidden_states, last_hidden_states \ = self._dummy_run(self.max_num_tokens, is_profile=True) + if get_pp_group().is_last_rank: if self.is_pooling_model: output = self._dummy_pooler_run(hidden_states) @@ -2280,24 +2878,31 @@ def capture_model(self) -> None: start_time = time.perf_counter() start_free_gpu_memory = torch.cuda.mem_get_info()[0] + logger.info("CAPTURE MODEL START") # Trigger CUDA graph capture for specific shapes. # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. 
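
The self.cudagraphs dictionary introduced earlier amounts to capture-once, replay-afterwards bookkeeping keyed by the padded token count. A CPU-only stand-in that mirrors the warmup-then-capture loop below, with the CUDA graph replaced by a memoised callable (illustrative names only):

class GraphCache:
    """Capture work once per batch size, then replay it (CPU stand-in)."""

    def __init__(self) -> None:
        self._graphs: dict[int, list[int]] = {}

    def capture(self, num_tokens: int, fn, num_warmups: int = 1) -> None:
        for _ in range(num_warmups):
            fn(num_tokens)                          # warm-up runs, not recorded
        self._graphs[num_tokens] = fn(num_tokens)   # stand-in for graph capture

    def replay(self, num_tokens: int) -> list[int]:
        return self._graphs[num_tokens]

cache = GraphCache()
for size in sorted([1, 2, 4, 8], reverse=True):     # largest shapes first
    cache.capture(size, lambda n: [0] * n)
assert cache.replay(4) == [0, 0, 0, 0]
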
with graph_capture(device=self.device): full_cg = self.full_cuda_graph + allow_microbatching = self.parallel_config.enable_microbatching for num_tokens in tqdm(reversed(self.cudagraph_batch_sizes), desc="Capturing CUDA graphs", total=len(self.cudagraph_batch_sizes)): # We skip EPLB here since we don't want to record dummy metrics for _ in range( self.compilation_config.cudagraph_num_of_warmups): - self._dummy_run(num_tokens, - capture_attn_cudagraph=full_cg, + self._dummy_run(num_tokens, + capture_attn_cudagraph=full_cg, + allow_microbatching=allow_microbatching, + build_cuda_graph=True, skip_eplb=True) - self._dummy_run(num_tokens, - capture_attn_cudagraph=full_cg, + self._dummy_run(num_tokens, + capture_attn_cudagraph=full_cg, + allow_microbatching=allow_microbatching, + build_cuda_graph=True, skip_eplb=True) + logger.info("CAPTURE MODEL END") end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] elapsed_time = end_time - start_time diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 9e7e44d06861..e02a22093b29 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -328,6 +328,7 @@ def profile(self, is_start: bool = True): sort_by="self_cuda_time_total")) def execute_dummy_batch(self) -> None: + # TODO: adding allow_microbatching will break non-gpu backends self.model_runner._dummy_run(1) def add_lora(self, lora_request: LoRARequest) -> bool: diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py new file mode 100644 index 000000000000..3defe34e06bf --- /dev/null +++ b/vllm/v1/worker/ubatching.py @@ -0,0 +1,249 @@ +# SPDX-License-Identifier: Apache-2.0 +import threading +from typing import Optional + +import torch +import torch._dynamo +from torch.library import custom_op + +from vllm import forward_context +from vllm.utils import current_stream +from vllm.distributed.parallel_state import get_dp_group + + +class UBatchContext: + """ + Context manager for micro-batching synchronization using threading events. 
+ """ + + def __init__( + self, + id: int, + comm_stream: torch.cuda.Stream, + compute_stream: torch.cuda.Stream, + #fwd_ctx: forward_context.ForwardContext, + cpu_wait_event: threading.Event, + cpu_signal_event: threading.Event, + gpu_comm_done_event: torch.cuda.Event, + gpu_compute_done_event: torch.cuda.Event, + schedule: str = "default"): + self.id = id + self.comm_stream = comm_stream + self.compute_stream = compute_stream + self.forward_context = None #fwd_ctx + self.cpu_wait_event = cpu_wait_event + self.cpu_signal_event = cpu_signal_event + self.current_stream = compute_stream + self.gpu_comm_done_event = gpu_comm_done_event + self.gpu_compute_done_event = gpu_compute_done_event + self.schedule = schedule + + def __enter__(self): + global _CURRENT_CONTEXT + _CURRENT_CONTEXT[threading.get_ident()] = self + + self.cpu_wait_event.clear() + self.cpu_wait_event.wait() + self.cpu_wait_event.clear() + self._restore_context() + # Assume we start on the compute stream + assert current_stream() == self.compute_stream + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + global _CURRENT_CONTEXT + _CURRENT_CONTEXT[threading.get_ident()] = None + # print("Finishing ubatch %d\n" % self.id, flush=True) + self.cpu_signal_event.set() + self.cpu_wait_event.clear() + self.current_stream = self.compute_stream + torch.cuda.set_stream(self.current_stream) + return False + + def _restore_context(self): + forward_context._forward_context = self.forward_context + torch.cuda.set_stream(self.current_stream) + + def update_stream(self, stream): + self.current_stream = stream + torch.cuda.set_stream(self.current_stream) + + def ctx_valid_state(self): + assert forward_context._forward_context == self.forward_context + assert current_stream() == self.current_stream + assert not self.cpu_wait_event.is_set() + pass + + def _signal_comm_done(self): + # assert False + self.ctx_valid_state() + self.gpu_comm_done_event.record(self.comm_stream) + + def _signal_compute_done(self): + # assert False + self.ctx_valid_state() + self.gpu_compute_done_event.record(self.compute_stream) + + def _wait_compute_done(self): + # assert False + # print(f"{self.id} Waiting on COMPUTE stream", flush=True) + self.ctx_valid_state() + self.comm_stream.wait_event(self.gpu_compute_done_event) + # print("Compute stream done", flush=True) + + def _wait_comm_done(self): + # assert False + # print(f"{self.id} Waiting on COMM stream", flush=True) + self.ctx_valid_state() + self.compute_stream.wait_event(self.gpu_comm_done_event) + # print("Comm stream done", flush=True) + + def stream_string(self): + # assert False + if current_stream() == self.compute_stream: + assert self.current_stream == self.compute_stream + return "COMPUTE" + elif current_stream() == self.comm_stream: + assert self.current_stream == self.comm_stream + return "COMM" + + def _cpu_yield(self): + # print(f"UBatchContext: {self.id} yielding CPU", flush=True) + self.ctx_valid_state() + self.cpu_signal_event.set() + self.cpu_wait_event.wait() + self.cpu_wait_event.clear() + self._restore_context() + self.ctx_valid_state() + # print(f"UBatchContext: {self.id} resuming CPU", flush=True) + + def yield_and_switch_from_compute_to_comm(self): + # assert False + assert current_stream() == self.compute_stream + # dp_rank = get_dp_group().rank_in_group + # print(f"DP: {dp_rank} UB: {self.id} " + # f"Yield and switch from {self.stream_string()}", flush=True) + self.ctx_valid_state() + self._signal_compute_done() + self._cpu_yield() + self.ctx_valid_state() + assert 
self.current_stream == self.compute_stream + self.update_stream(self.comm_stream) + # print(f"DP: {dp_rank} UB: {self.id} " + # f"Resuming on stream {self.stream_string()}", flush=True) + self._wait_compute_done() + + def yield_and_switch_from_comm_to_compute(self): + # assert False + assert current_stream() == self.comm_stream + # dp_rank = get_dp_group().rank_in_group + # print(f"DP: {dp_rank} UB: {self.id} " + # f"Yield and switch from {self.stream_string()}", flush=True) + self.ctx_valid_state() + self._signal_comm_done() + self._cpu_yield() + self.ctx_valid_state() + assert self.current_stream == self.comm_stream + self.update_stream(self.compute_stream) + # print(f"DP: {dp_rank} UB: {self.id} " + # f"Resuming on stream {self.stream_string()}", flush=True) + self._wait_comm_done() + + +_CURRENT_CONTEXT: dict = {} + + +def get_current_ubatch_context() -> Optional[UBatchContext]: + global _CURRENT_CONTEXT + """ + Get the current UBatchContext for the current thread. + """ + return _CURRENT_CONTEXT.get(threading.get_ident(), None) + + +def yield_and_switch_from_compute_to_comm_impl(schedule="default"): + # Perform the barrier if a context exists for this thread + ctx = get_current_ubatch_context() + #print("you are in yield_impl", ctx) + if ctx is not None and ctx.schedule == schedule: + ctx.yield_and_switch_from_compute_to_comm() + + +def yield_and_switch_from_comm_to_compute_impl(schedule="default"): + # Perform the barrier if a context exists for this thread + ctx = get_current_ubatch_context() + if ctx is not None and ctx.schedule == schedule: + ctx.yield_and_switch_from_comm_to_compute() + + +# 2) Register kernel for CUDA, mark as mutating to prevent the compiler from +# optimizing it away (TODO: see if this is actually needed) +@custom_op("vllm::yield_and_switch_from_compute_to_comm", mutates_args=("x", )) +def yield_and_switch_from_compute_to_comm(x: torch.Tensor, + schedule: str = "default") -> None: + yield_and_switch_from_compute_to_comm_impl(schedule) + + +# 3) Fake implementation for shape prop and FX tracing +@yield_and_switch_from_compute_to_comm.register_fake +def yield_and_switch_from_compute_to_comm_fake(x: torch.Tensor, + schedule: str = "default" + ) -> None: + pass + + +@custom_op("vllm::yield_and_switch_from_comm_to_compute", mutates_args=("x", )) +def yield_and_switch_from_comm_to_compute(x: torch.Tensor, + schedule: str = "default") -> None: + yield_and_switch_from_comm_to_compute_impl(schedule) + + +@yield_and_switch_from_comm_to_compute.register_fake +def yield_and_switch_from_comm_to_compute_fake(x: torch.Tensor, + schedule: str = "default" + ) -> None: + pass + + +def dump_ubatching_state(): + pass + + +""" +""" + + +def make_ubatch_contexts( + num_micro_batches: int, + compute_stream: torch.cuda.Stream, + device: Optional[torch.device] = None, + schedule: str = "default", +) -> list[UBatchContext]: + assert num_micro_batches == 2, "only been tested with 2 micro-batches" + """ + Create a context manager for micro-batching synchronization. 
+    """
+    cpu_events = [threading.Event() for _ in range(num_micro_batches)]
+    gpu_comm_done_events = [
+        torch.cuda.Event() for _ in range(num_micro_batches)
+    ]
+    gpu_compute_done_events = [
+        torch.cuda.Event() for _ in range(num_micro_batches)
+    ]
+    device = device or torch.cuda.current_device()
+    comm_stream = torch.cuda.Stream(device)
+
+    ctxs = []
+    for i in range(num_micro_batches):
+        ctx = UBatchContext(id=i,
+                            compute_stream=compute_stream,
+                            comm_stream=comm_stream,
+                            cpu_wait_event=cpu_events[i],
+                            cpu_signal_event=cpu_events[(i + 1) %
+                                                        num_micro_batches],
+                            gpu_comm_done_event=gpu_comm_done_events[i],
+                            gpu_compute_done_event=gpu_compute_done_events[i],
+                            schedule=schedule)
+        ctxs.append(ctx)
+
+    return ctxs
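
The CPU-side hand-off that _cpu_yield and __exit__ implement with cpu_wait_event and cpu_signal_event reduces to a ping-pong between two threads. A runnable, CPU-only sketch of that pattern (the real contexts additionally switch CUDA streams and record GPU events at every hand-off):

import threading

def worker(idx: int, steps: int, wait_events: list, log: list) -> None:
    other = (idx + 1) % 2
    for step in range(steps):
        wait_events[idx].wait()          # wait until the other ubatch yields
        wait_events[idx].clear()
        log.append((idx, step))          # this ubatch's slice of work
        wait_events[other].set()         # hand the CPU to the other ubatch

wait_events = [threading.Event(), threading.Event()]
log: list = []
threads = [threading.Thread(target=worker, args=(i, 3, wait_events, log))
           for i in range(2)]
for t in threads:
    t.start()
wait_events[0].set()                     # release micro-batch 0 first
for t in threads:
    t.join()
assert [idx for idx, _ in log] == [0, 1, 0, 1, 0, 1]
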