16 changes: 16 additions & 0 deletions vllm_mlx/engine/batched.py
@@ -464,6 +464,10 @@ async def generate(
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=kwargs.pop("top_k", 0),
min_p=kwargs.pop("min_p", 0.0),
presence_penalty=kwargs.pop("presence_penalty", 0.0),
repetition_penalty=kwargs.pop("repetition_penalty", 1.0),
)

return GenerationOutput(
@@ -480,6 +484,10 @@ async def generate(
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=kwargs.pop("top_k", 0),
min_p=kwargs.pop("min_p", 0.0),
presence_penalty=kwargs.pop("presence_penalty", 0.0),
repetition_penalty=kwargs.pop("repetition_penalty", 1.0),
stop=stop or [],
)

@@ -536,6 +544,10 @@ async def stream_generate(
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=kwargs.pop("top_k", 0),
min_p=kwargs.pop("min_p", 0.0),
presence_penalty=kwargs.pop("presence_penalty", 0.0),
repetition_penalty=kwargs.pop("repetition_penalty", 1.0),
)

async for output in self._mllm_scheduler.stream_outputs(request_id):
@@ -556,6 +568,10 @@ async def stream_generate(
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=kwargs.pop("top_k", 0),
min_p=kwargs.pop("min_p", 0.0),
presence_penalty=kwargs.pop("presence_penalty", 0.0),
repetition_penalty=kwargs.pop("repetition_penalty", 1.0),
stop=stop or [],
)
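
The four new knobs ride along in **kwargs and are consumed with kwargs.pop, so callers that omit them fall back to the neutral defaults (top_k=0, min_p=0.0, presence_penalty=0.0, repetition_penalty=1.0). A minimal caller sketch; only the sampling keyword names come from the diff, while the engine object, prompt argument, and values are illustrative:

import asyncio

async def demo(engine):
    # Hypothetical call into the engine's generate() shown above.
    # Omitting any of the four new keywords leaves its neutral default in place.
    return await engine.generate(
        "Describe the attached image in one sentence.",
        max_tokens=128,
        temperature=0.7,
        top_p=0.9,
        top_k=40,                # 0 keeps top-k filtering disabled
        min_p=0.05,              # 0.0 keeps min-p filtering disabled
        presence_penalty=0.5,    # 0.0 applies no penalty
        repetition_penalty=1.1,  # 1.0 applies no penalty
    )

# asyncio.run(demo(engine))  # with a concrete engine instance

Because dict.pop removes each key as it is read, each knob is consumed exactly once per request path.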

118 changes: 114 additions & 4 deletions vllm_mlx/mllm_batch_generator.py
@@ -47,6 +47,10 @@ class MLLMBatchRequest:
max_tokens: int = 256
temperature: float = 0.7
top_p: float = 0.9
top_k: int = 0
min_p: float = 0.0
presence_penalty: float = 0.0
repetition_penalty: float = 1.0

# Processed inputs (set after vision preprocessing)
input_ids: Optional[mx.array] = None
@@ -98,6 +102,8 @@ class MLLMBatch:
num_tokens: List[int] # Tokens generated per request
cache: List[Any] # BatchKVCache for language model
requests: List[MLLMBatchRequest] # Full request data
logits_processors: Optional[List[Optional[List[Callable]]]] = None
samplers: Optional[List[Optional[Callable]]] = None

def __len__(self) -> int:
return len(self.uids)
@@ -115,6 +121,10 @@ def filter(self, keep_idx: List[int]) -> None:
self.max_tokens = [self.max_tokens[k] for k in keep_idx]
self.num_tokens = [self.num_tokens[k] for k in keep_idx]
self.requests = [self.requests[k] for k in keep_idx]
if self.logits_processors is not None:
self.logits_processors = [self.logits_processors[k] for k in keep_idx]
if self.samplers is not None:
self.samplers = [self.samplers[k] for k in keep_idx]

keep_idx_array = mx.array(keep_idx, mx.int32)
self.y = self.y[keep_idx_array]
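
filter() slices every per-request list with the same keep_idx, so the new logits_processors and samplers stay index-aligned with uids and with the rows kept in self.y. A tiny plain-Python illustration with stand-in values:

# Stand-ins for the callables (or None) stored per request.
keep_idx = [0, 2]
logits_processors = [["rep_penalty"], None, ["pres_penalty"]]
samplers = [None, "sampler_b", "sampler_c"]

logits_processors = [logits_processors[k] for k in keep_idx]  # [["rep_penalty"], ["pres_penalty"]]
samplers = [samplers[k] for k in keep_idx]                     # [None, "sampler_c"]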
@@ -139,6 +149,20 @@ def extend(self, other: "MLLMBatch") -> None:
self.max_tokens.extend(other.max_tokens)
self.requests.extend(other.requests)

# Extend logits_processors
if self.logits_processors is not None or other.logits_processors is not None:
self_len = len(self.uids) - len(other.uids)
self_lp = self.logits_processors or [None] * self_len
other_lp = other.logits_processors or [None] * len(other.uids)
self.logits_processors = list(self_lp) + list(other_lp)

# Extend samplers
if self.samplers is not None or other.samplers is not None:
self_len = len(self.uids) - len(other.uids)
self_s = self.samplers or [None] * self_len
other_s = other.samplers or [None] * len(other.uids)
self.samplers = list(self_s) + list(other_s)

# Extend cache - handle None and incompatible caches
for c, o in zip(self.cache, other.cache):
if c is not None and o is not None and hasattr(c, "extend"):
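
In extend(), the uid and token lists have already been merged by the time the processor and sampler lists are handled, so self_len = len(self.uids) - len(other.uids) recovers the batch's pre-merge length; whichever side has no list is padded with None so indices stay aligned. A standalone sketch of that padding rule (hypothetical helper name):

def merge_optional_lists(self_list, other_list, self_len, other_len):
    # If neither batch carries per-request entries, keep the merged value as None.
    if self_list is None and other_list is None:
        return None
    # Pad the side that had no entries with None placeholders, then concatenate.
    left = self_list or [None] * self_len
    right = other_list or [None] * other_len
    return list(left) + list(right)

# merge_optional_lists(None, ["s1"], self_len=2, other_len=1) -> [None, None, "s1"]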
@@ -692,6 +716,51 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch:
# Create initial y (first generated tokens)
y = mx.array(first_tokens)

# Build per-request logits processors (repetition_penalty, presence_penalty)
from mlx_lm.sample_utils import make_logits_processors, make_sampler

batch_logits_processors = []
has_any_lp = False
for req in requests:
need_rep = req.repetition_penalty and req.repetition_penalty != 1.0
need_pres = req.presence_penalty and req.presence_penalty != 0.0
if need_rep or need_pres:
lp_kwargs = {}
if need_rep:
lp_kwargs["repetition_penalty"] = req.repetition_penalty
if need_pres:
lp_kwargs["presence_penalty"] = req.presence_penalty
lp = make_logits_processors(**lp_kwargs)
batch_logits_processors.append(lp)
has_any_lp = True
logger.info(
f"[sampling] request={req.request_id[:12]} "
f"rep_penalty={req.repetition_penalty} "
f"pres_penalty={req.presence_penalty}"
)
else:
batch_logits_processors.append(None)

# Build per-request samplers for top_k/min_p
batch_samplers = []
has_any_sampler = False
for req in requests:
if req.top_k != 0 or req.min_p != 0.0:
s = make_sampler(
temp=req.temperature,
top_p=req.top_p,
top_k=req.top_k,
min_p=req.min_p,
)
batch_samplers.append(s)
has_any_sampler = True
logger.info(
f"[sampling] request={req.request_id[:12]} "
f"top_k={req.top_k} min_p={req.min_p}"
)
else:
batch_samplers.append(None)

self._stats.prompt_time += time.perf_counter() - tic

return MLLMBatch(
@@ -703,17 +772,27 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch:
num_tokens=[0] * len(requests),
cache=batch_cache,
requests=requests,
logits_processors=batch_logits_processors if has_any_lp else None,
samplers=batch_samplers if has_any_sampler else None,
)
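
The per-request objects come from mlx_lm's sample_utils helpers: make_logits_processors for the penalty terms and make_sampler for top-k and min-p. A standalone sketch of the same construction; it assumes the mlx_lm API as used in the diff (including the presence_penalty keyword) and a hypothetical req object with the fields of MLLMBatchRequest:

from mlx_lm.sample_utils import make_logits_processors, make_sampler

def build_sampling_objects(req):
    # Hypothetical helper mirroring the diff's per-request construction.
    processors = None
    need_rep = req.repetition_penalty and req.repetition_penalty != 1.0
    need_pres = req.presence_penalty and req.presence_penalty != 0.0
    if need_rep or need_pres:
        kwargs = {}
        if need_rep:
            kwargs["repetition_penalty"] = req.repetition_penalty
        if need_pres:
            kwargs["presence_penalty"] = req.presence_penalty
        # A list of callables with signature (generated_tokens, logits) -> logits.
        processors = make_logits_processors(**kwargs)

    sampler = None
    if req.top_k != 0 or req.min_p != 0.0:
        # A callable mapping log-probabilities to sampled token ids.
        sampler = make_sampler(
            temp=req.temperature,
            top_p=req.top_p,
            top_k=req.top_k,
            min_p=req.min_p,
        )
    return processors, sampler

Requests that need neither feature keep None in both slots and fall back to the batch-wide sampler.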

def _step(
self, input_tokens: mx.array, cache: List[Any]
self,
input_tokens: mx.array,
cache: List[Any],
logits_processors: Optional[List[Optional[List[Callable]]]] = None,
output_tokens: Optional[List[List[int]]] = None,
samplers: Optional[List[Optional[Callable]]] = None,
) -> Tuple[mx.array, List[mx.array]]:
"""
Run one generation step through the language model.

Args:
input_tokens: Input tokens [batch_size, 1] or [batch_size]
cache: BatchKVCache for the language model
logits_processors: Per-request logits processors (e.g. repetition penalty)
output_tokens: Per-request generated tokens so far (needed by processors)
samplers: Per-request sampler functions (for top_k/min_p)

Returns:
Tuple of (sampled tokens, logprobs list)
@@ -733,9 +812,29 @@ def _step(

logits = logits[:, -1, :]

# Sample
# Apply per-request logits processors (repetition penalty etc.)
if logits_processors and output_tokens and any(logits_processors):
processed_logits = []
for e in range(logits.shape[0]):
sample_logits = logits[e : e + 1]
if logits_processors[e]:
for processor in logits_processors[e]:
sample_logits = processor(
mx.array(output_tokens[e]), sample_logits
)
processed_logits.append(sample_logits)
logits = mx.concatenate(processed_logits, axis=0)

# Sample — per-request samplers for top_k/min_p support
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
sampled = self.sampler(logprobs)
if samplers and any(samplers):
sampled_list = []
for e in range(logprobs.shape[0]):
s = samplers[e] if samplers[e] else self.sampler
sampled_list.append(s(logprobs[e : e + 1]))
sampled = mx.concatenate(sampled_list, axis=0)
else:
sampled = self.sampler(logprobs)

return sampled, list(logprobs)
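
Rows with processors or a custom sampler are handled one at a time, and everything else falls back to the shared self.sampler. A compact stand-in for that per-row dispatch, decoupled from the model code; the processor and sampler callables follow the mlx_lm conventions assumed above:

import mlx.core as mx

def apply_per_request_sampling(logits, histories, processors, samplers, default_sampler):
    # Illustrative equivalent of the per-row loop in _step.
    rows = []
    for i in range(logits.shape[0]):
        row = logits[i : i + 1]
        for proc in processors[i] or []:
            # Each processor sees the tokens this request has generated so far.
            row = proc(mx.array(histories[i]), row)
        rows.append(row)
    logits = mx.concatenate(rows, axis=0)

    logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
    sampled = []
    for i in range(logprobs.shape[0]):
        sampler = samplers[i] or default_sampler
        sampled.append(sampler(logprobs[i : i + 1]))
    return mx.concatenate(sampled, axis=0), logprobs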

@@ -775,7 +874,18 @@ def _next(self) -> List[MLLMBatchResponse]:
return []

y, logprobs = batch.y, batch.logprobs
batch.y, batch.logprobs = self._step(y[:, None], batch.cache)
output_tokens = (
[req.output_tokens for req in batch.requests]
if batch.logits_processors
else None
)
batch.y, batch.logprobs = self._step(
y[:, None],
batch.cache,
batch.logits_processors,
output_tokens,
batch.samplers,
)
mx.async_eval(batch.y, batch.logprobs)

y = y.tolist()
8 changes: 8 additions & 0 deletions vllm_mlx/mllm_scheduler.py
@@ -297,6 +297,10 @@ def add_request(
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=kwargs.pop("top_k", 0),
min_p=kwargs.pop("min_p", 0.0),
presence_penalty=kwargs.pop("presence_penalty", 0.0),
repetition_penalty=kwargs.pop("repetition_penalty", 1.0),
)

request = MLLMRequest(
@@ -403,6 +407,10 @@ def _schedule_waiting(self) -> List[MLLMRequest]:
max_tokens=request.sampling_params.max_tokens,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
top_k=request.sampling_params.top_k,
min_p=request.sampling_params.min_p,
presence_penalty=request.sampling_params.presence_penalty,
repetition_penalty=request.sampling_params.repetition_penalty,
)
batch_requests.append(batch_req)

1 change: 1 addition & 0 deletions vllm_mlx/request.py
@@ -57,6 +57,7 @@ class SamplingParams:
top_p: float = 0.9
top_k: int = 0 # 0 means disabled
min_p: float = 0.0
presence_penalty: float = 0.0
repetition_penalty: float = 1.0
stop: Optional[List[str]] = None
stop_token_ids: Optional[List[int]] = None
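
With presence_penalty in place, SamplingParams now carries all four knobs alongside their neutral defaults. A quick construction sketch; the import path simply mirrors the file location shown above, and fields not visible in the diff may impose their own requirements:

from vllm_mlx.request import SamplingParams

params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    top_k=40,                 # 0 means disabled
    min_p=0.05,
    presence_penalty=0.5,     # new field; 0.0 applies no penalty
    repetition_penalty=1.1,   # 1.0 applies no penalty
)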