From 4116542cf2f3e6a6f903965b3d153d7136e7ee8d Mon Sep 17 00:00:00 2001
From: Jan Hilgard
Date: Mon, 6 Apr 2026 12:26:08 +0200
Subject: [PATCH] feat: add full sampling params support for MLLM continuous batching
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Extend the MLLM batch generator to support top_k, min_p, and presence_penalty
alongside the existing repetition_penalty support. This brings the MLLM path
to full parity with the LLM/SimpleEngine sampling parameter coverage.

Changes:
- MLLMBatchRequest: add top_k, min_p, presence_penalty, repetition_penalty fields
- MLLMBatch: add per-request logits_processors and samplers lists
  (filter/extend support)
- _process_prompts: build per-request logits processors for
  repetition_penalty/presence_penalty and per-request samplers for top_k/min_p
- _step: accept and apply per-request logits processors and samplers
- SamplingParams: add presence_penalty field
- MLLMScheduler: propagate new params from kwargs to batch requests
- BatchedEngine: pass new params through generate/stream_generate

When a request uses the default values (top_k=0, min_p=0.0,
presence_penalty=0.0, repetition_penalty=1.0), no extra processors or
samplers are created, so standard requests incur no additional overhead.

Co-Authored-By: Claude Opus 4.6
---
 vllm_mlx/engine/batched.py       |  16 +++
 vllm_mlx/mllm_batch_generator.py | 118 +++++++++++++++++++++++++++++--
 vllm_mlx/mllm_scheduler.py       |   8 +++
 vllm_mlx/request.py              |   1 +
 4 files changed, 139 insertions(+), 4 deletions(-)

diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py
index 3ac52b4b0..d71cc108a 100644
--- a/vllm_mlx/engine/batched.py
+++ b/vllm_mlx/engine/batched.py
@@ -464,6 +464,10 @@ async def generate(
             max_tokens=max_tokens,
             temperature=temperature,
             top_p=top_p,
+            top_k=kwargs.pop("top_k", 0),
+            min_p=kwargs.pop("min_p", 0.0),
+            presence_penalty=kwargs.pop("presence_penalty", 0.0),
+            repetition_penalty=kwargs.pop("repetition_penalty", 1.0),
         )

         return GenerationOutput(
@@ -480,6 +484,10 @@ async def generate(
             max_tokens=max_tokens,
             temperature=temperature,
             top_p=top_p,
+            top_k=kwargs.pop("top_k", 0),
+            min_p=kwargs.pop("min_p", 0.0),
+            presence_penalty=kwargs.pop("presence_penalty", 0.0),
+            repetition_penalty=kwargs.pop("repetition_penalty", 1.0),
             stop=stop or [],
         )

@@ -536,6 +544,10 @@ async def stream_generate(
             max_tokens=max_tokens,
             temperature=temperature,
             top_p=top_p,
+            top_k=kwargs.pop("top_k", 0),
+            min_p=kwargs.pop("min_p", 0.0),
+            presence_penalty=kwargs.pop("presence_penalty", 0.0),
+            repetition_penalty=kwargs.pop("repetition_penalty", 1.0),
         )

         async for output in self._mllm_scheduler.stream_outputs(request_id):
@@ -556,6 +568,10 @@ async def stream_generate(
             max_tokens=max_tokens,
             temperature=temperature,
             top_p=top_p,
+            top_k=kwargs.pop("top_k", 0),
+            min_p=kwargs.pop("min_p", 0.0),
+            presence_penalty=kwargs.pop("presence_penalty", 0.0),
+            repetition_penalty=kwargs.pop("repetition_penalty", 1.0),
             stop=stop or [],
         )

diff --git a/vllm_mlx/mllm_batch_generator.py b/vllm_mlx/mllm_batch_generator.py
index ee8d8da7b..9c612e842 100644
--- a/vllm_mlx/mllm_batch_generator.py
+++ b/vllm_mlx/mllm_batch_generator.py
@@ -47,6 +47,10 @@ class MLLMBatchRequest:
     max_tokens: int = 256
     temperature: float = 0.7
     top_p: float = 0.9
+    top_k: int = 0
+    min_p: float = 0.0
+    presence_penalty: float = 0.0
+    repetition_penalty: float = 1.0

     # Processed inputs (set after vision preprocessing)
     input_ids: Optional[mx.array] = None
@@ -98,6 +102,8 @@ class MLLMBatch:
     num_tokens: List[int]  # Tokens generated per request
     cache: List[Any]  # BatchKVCache for language model
     requests: List[MLLMBatchRequest]  # Full request data
+    logits_processors: Optional[List[Optional[List[Callable]]]] = None
+    samplers: Optional[List[Optional[Callable]]] = None

     def __len__(self) -> int:
         return len(self.uids)
@@ -115,6 +121,10 @@ def filter(self, keep_idx: List[int]) -> None:
         self.max_tokens = [self.max_tokens[k] for k in keep_idx]
         self.num_tokens = [self.num_tokens[k] for k in keep_idx]
         self.requests = [self.requests[k] for k in keep_idx]
+        if self.logits_processors is not None:
+            self.logits_processors = [self.logits_processors[k] for k in keep_idx]
+        if self.samplers is not None:
+            self.samplers = [self.samplers[k] for k in keep_idx]

         keep_idx_array = mx.array(keep_idx, mx.int32)
         self.y = self.y[keep_idx_array]
@@ -139,6 +149,20 @@ def extend(self, other: "MLLMBatch") -> None:
         self.max_tokens.extend(other.max_tokens)
         self.requests.extend(other.requests)

+        # Extend logits_processors
+        if self.logits_processors is not None or other.logits_processors is not None:
+            self_len = len(self.uids) - len(other.uids)
+            self_lp = self.logits_processors or [None] * self_len
+            other_lp = other.logits_processors or [None] * len(other.uids)
+            self.logits_processors = list(self_lp) + list(other_lp)
+
+        # Extend samplers
+        if self.samplers is not None or other.samplers is not None:
+            self_len = len(self.uids) - len(other.uids)
+            self_s = self.samplers or [None] * self_len
+            other_s = other.samplers or [None] * len(other.uids)
+            self.samplers = list(self_s) + list(other_s)
+
         # Extend cache - handle None and incompatible caches
         for c, o in zip(self.cache, other.cache):
             if c is not None and o is not None and hasattr(c, "extend"):
@@ -692,6 +716,51 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch:
         # Create initial y (first generated tokens)
         y = mx.array(first_tokens)

+        # Build per-request logits processors (repetition_penalty, presence_penalty)
+        from mlx_lm.sample_utils import make_logits_processors, make_sampler
+
+        batch_logits_processors = []
+        has_any_lp = False
+        for req in requests:
+            need_rep = req.repetition_penalty and req.repetition_penalty != 1.0
+            need_pres = req.presence_penalty and req.presence_penalty != 0.0
+            if need_rep or need_pres:
+                lp_kwargs = {}
+                if need_rep:
+                    lp_kwargs["repetition_penalty"] = req.repetition_penalty
+                if need_pres:
+                    lp_kwargs["presence_penalty"] = req.presence_penalty
+                lp = make_logits_processors(**lp_kwargs)
+                batch_logits_processors.append(lp)
+                has_any_lp = True
+                logger.info(
+                    f"[sampling] request={req.request_id[:12]} "
+                    f"rep_penalty={req.repetition_penalty} "
+                    f"pres_penalty={req.presence_penalty}"
+                )
+            else:
+                batch_logits_processors.append(None)
+
+        # Build per-request samplers for top_k/min_p
+        batch_samplers = []
+        has_any_sampler = False
+        for req in requests:
+            if req.top_k != 0 or req.min_p != 0.0:
+                s = make_sampler(
+                    temp=req.temperature,
+                    top_p=req.top_p,
+                    top_k=req.top_k,
+                    min_p=req.min_p,
+                )
+                batch_samplers.append(s)
+                has_any_sampler = True
+                logger.info(
+                    f"[sampling] request={req.request_id[:12]} "
+                    f"top_k={req.top_k} min_p={req.min_p}"
+                )
+            else:
+                batch_samplers.append(None)
+
         self._stats.prompt_time += time.perf_counter() - tic

         return MLLMBatch(
@@ -703,10 +772,17 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch:
             num_tokens=[0] * len(requests),
             cache=batch_cache,
             requests=requests,
+            logits_processors=batch_logits_processors if has_any_lp else None,
+            samplers=batch_samplers if has_any_sampler else None,
         )

     def _step(
-        self, input_tokens: mx.array, cache: List[Any]
+        self,
+        input_tokens: mx.array,
+        cache: List[Any],
+        logits_processors: Optional[List[Optional[List[Callable]]]] = None,
+        output_tokens: Optional[List[List[int]]] = None,
+        samplers: Optional[List[Optional[Callable]]] = None,
     ) -> Tuple[mx.array, List[mx.array]]:
         """
         Run one generation step through the language model.
@@ -714,6 +790,9 @@
         Args:
             input_tokens: Input tokens [batch_size, 1] or [batch_size]
             cache: BatchKVCache for the language model
+            logits_processors: Per-request logits processors (e.g. repetition penalty)
+            output_tokens: Per-request generated tokens so far (needed by processors)
+            samplers: Per-request sampler functions (for top_k/min_p)

         Returns:
             Tuple of (sampled tokens, logprobs list)
@@ -733,9 +812,29 @@

         logits = logits[:, -1, :]

-        # Sample
+        # Apply per-request logits processors (repetition penalty etc.)
+        if logits_processors and output_tokens and any(logits_processors):
+            processed_logits = []
+            for e in range(logits.shape[0]):
+                sample_logits = logits[e : e + 1]
+                if logits_processors[e]:
+                    for processor in logits_processors[e]:
+                        sample_logits = processor(
+                            mx.array(output_tokens[e]), sample_logits
+                        )
+                processed_logits.append(sample_logits)
+            logits = mx.concatenate(processed_logits, axis=0)
+
+        # Sample — per-request samplers for top_k/min_p support
         logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
-        sampled = self.sampler(logprobs)
+        if samplers and any(samplers):
+            sampled_list = []
+            for e in range(logprobs.shape[0]):
+                s = samplers[e] if samplers[e] else self.sampler
+                sampled_list.append(s(logprobs[e : e + 1]))
+            sampled = mx.concatenate(sampled_list, axis=0)
+        else:
+            sampled = self.sampler(logprobs)

         return sampled, list(logprobs)

@@ -775,7 +874,18 @@ def _next(self) -> List[MLLMBatchResponse]:
             return []

         y, logprobs = batch.y, batch.logprobs
-        batch.y, batch.logprobs = self._step(y[:, None], batch.cache)
+        output_tokens = (
+            [req.output_tokens for req in batch.requests]
+            if batch.logits_processors
+            else None
+        )
+        batch.y, batch.logprobs = self._step(
+            y[:, None],
+            batch.cache,
+            batch.logits_processors,
+            output_tokens,
+            batch.samplers,
+        )
         mx.async_eval(batch.y, batch.logprobs)

         y = y.tolist()
diff --git a/vllm_mlx/mllm_scheduler.py b/vllm_mlx/mllm_scheduler.py
index 555b230f2..d2c7b4bb5 100644
--- a/vllm_mlx/mllm_scheduler.py
+++ b/vllm_mlx/mllm_scheduler.py
@@ -297,6 +297,10 @@ def add_request(
             max_tokens=max_tokens,
             temperature=temperature,
             top_p=top_p,
+            top_k=kwargs.pop("top_k", 0),
+            min_p=kwargs.pop("min_p", 0.0),
+            presence_penalty=kwargs.pop("presence_penalty", 0.0),
+            repetition_penalty=kwargs.pop("repetition_penalty", 1.0),
         )

         request = MLLMRequest(
@@ -403,6 +407,10 @@ def _schedule_waiting(self) -> List[MLLMRequest]:
                 max_tokens=request.sampling_params.max_tokens,
                 temperature=request.sampling_params.temperature,
                 top_p=request.sampling_params.top_p,
+                top_k=request.sampling_params.top_k,
+                min_p=request.sampling_params.min_p,
+                presence_penalty=request.sampling_params.presence_penalty,
+                repetition_penalty=request.sampling_params.repetition_penalty,
             )
             batch_requests.append(batch_req)

diff --git a/vllm_mlx/request.py b/vllm_mlx/request.py
index 41679c0ba..f18b238d8 100644
--- a/vllm_mlx/request.py
+++ b/vllm_mlx/request.py
@@ -57,6 +57,7 @@ class SamplingParams:
     top_p: float = 0.9
     top_k: int = 0  # 0 means disabled
     min_p: float = 0.0
+    presence_penalty: float = 0.0
     repetition_penalty: float = 1.0
     stop: Optional[List[str]] = None
     stop_token_ids: Optional[List[int]] = None
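
Usage sketch (illustrative note, not part of the patch): after this change the
new fields can be set on SamplingParams and are propagated by MLLMScheduler
into each MLLMBatchRequest. The snippet below assumes the module path matches
the patched file vllm_mlx/request.py and that SamplingParams accepts its fields
as keyword arguments; the values are examples only. Any non-default value opts
that single request into a per-request logits processor or sampler, while
requests that keep the defaults reuse the shared batch sampler with no extra
objects created.

    from vllm_mlx.request import SamplingParams

    # Example values only; defaults (top_k=0, min_p=0.0, presence_penalty=0.0,
    # repetition_penalty=1.0) would add no per-request processors or samplers.
    params = SamplingParams(
        temperature=0.7,
        top_p=0.9,
        top_k=40,                # non-zero: per-request sampler is built
        min_p=0.05,              # non-zero: per-request sampler is built
        presence_penalty=0.5,    # non-zero: per-request logits processor
        repetition_penalty=1.1,  # != 1.0: per-request logits processor
    )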