Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions mlx_lm/SERVER.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,24 @@ curl localhost:8080/v1/chat/completions \
- `min_p`: (Optional) A float specifying the min-p sampling parameter.
Defaults to `0.0` (disabled).

- `repetition_penalty`: (Optional) Applies a penalty to repeated tokens.
Defaults to `1.0`.
- `repetition_penalty`: (Optional) Applies a multiplicative penalty to repeated
tokens. Defaults to `0.0` (disabled).

- `repetition_context_size`: (Optional) The size of the context window for
applying repetition penalty. Defaults to `20`.

- `presence_penalty`: (Optional) Applies an additive penalty to tokens
that appeared before. Defaults to `0.0` (disabled).

- `presence_context_size`: (Optional) The size of the context window for
applying presence penalty. Defaults to `20`.

- `frequency_penalty`: (Optional) Applies an additive penalty proportional to
how many times a token appeared previously. Defaults to `0.0` (disabled).

- `frequency_context_size`: (Optional) The size of the context window for
applying frequency penalty. Defaults to `20`.

- `logit_bias`: (Optional) A dictionary mapping token IDs to their bias
values. Defaults to `None`.

Expand Down
89 changes: 81 additions & 8 deletions mlx_lm/sample_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,28 @@ def make_logits_processors(
logit_bias: Optional[Dict[int, float]] = None,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = 20,
presence_penalty: Optional[float] = None,
presence_context_size: Optional[int] = 20,
frequency_penalty: Optional[float] = None,
frequency_context_size: Optional[int] = 20,
):
"""
Make logits processors for use with ``generate_step``.

Args:
repetition_penalty (float, optional): The penalty factor for repeating
tokens.
repetition_penalty (float, optional): A (sign-aware) multiplicative
penalty for repeating tokens.
repetition_context_size (int, optional): The number of tokens to
consider for repetition penalty. Default: ``20``.
presence_penalty (float, optional): An additive penalty to reduce
repeating tokens.
presence_context_size (int, optional): The number of tokens to consider
for the presence penalty. Default: ``20``.
frequency_penalty (float, optional): An additive penalty to reduce
repeating tokens. The tokens are penalized proportionally to their
frequency.
frequency_context_size (int, optional): The number of tokens to consider
for the frequency penalty. Default: ``20``.
logit_bias (dictionary, optional): Additive logit bias.

Returns:
Expand All @@ -96,15 +109,20 @@ def make_logits_processors(
values = mx.array(list(logit_bias.values()))

def logit_bias_processor(_, logits):
logits[:, indices] += values
return logits
return logits.at[:, indices].add(values)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I remember correctly, you mentioned once that .at[].add() is needed to use Scatter::Sum instead of Gather + Add + Scatter::None, right?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly what you wrote. The reason I changed it is because the pattern Gather -> Add -> Scatter can be significantly more inefficient when the Gather or Add depends on the src of scatter because it breaks donation and Scatter ends up doing a copy.


logits_processors.append(logit_bias_processor)

if repetition_penalty and repetition_penalty != 0.0:
logits_processors.append(
make_repetition_penalty(repetition_penalty, repetition_context_size)
)
repetition_penalties = [
(make_repetition_penalty, repetition_penalty, repetition_context_size),
(make_presence_penalty, presence_penalty, presence_context_size),
(make_frequency_penalty, frequency_penalty, frequency_context_size),
]

for make_penalty, penalty, context_size in repetition_penalties:
if penalty is not None and penalty != 0:
logits_processors.append(make_penalty(penalty, context_size))

return logits_processors


Expand Down Expand Up @@ -307,3 +325,58 @@ def repetition_penalty_processor(tokens, logits):
return logits

return repetition_penalty_processor


def make_presence_penalty(penalty: float, context_size: int = 20):
    """
    Build a presence penalty logits processor.

    Mirrors the OpenAI option of the same name: ``penalty`` is subtracted
    from a token's logit if that token occurred at least once among the
    ``context_size`` most recent tokens.

    Args:
        penalty (float): The presence penalty to subtract.
        context_size (int): How many trailing tokens to inspect.
            Default: ``20``.

    Returns:
        Callable[[mx.array, List[int]], mx.array]
    """

    def presence_penalty_processor(tokens, logits):
        # Nothing to penalize until at least one token has been generated.
        if len(tokens) == 0:
            return logits
        recent = tokens[-context_size:]
        # In-place fancy-index subtraction writes the same penalized value
        # for every duplicate index, so each token in the window is
        # penalized exactly once no matter how often it repeats.
        logits[:, recent] -= penalty
        return logits

    return presence_penalty_processor


def make_frequency_penalty(penalty: float, context_size: int = 20):
    """
    Build a frequency penalty logits processor.

    Mirrors the OpenAI option of the same name: ``penalty`` is subtracted
    from a token's logit once for every occurrence of that token among the
    ``context_size`` most recent tokens.

    Unlike the presence penalty, a token is penalized more heavily the more
    often it has occurred.

    Args:
        penalty (float): The penalty to subtract per occurrence.
        context_size (int): How many trailing tokens to inspect.
            Default: ``20``.

    Returns:
        Callable[[mx.array, List[int]], mx.array]
    """

    def frequency_penalty_processor(tokens, logits):
        # Nothing to penalize until at least one token has been generated.
        if len(tokens) == 0:
            return logits
        recent = tokens[-context_size:]
        # ``at[...].subtract`` accumulates over duplicate indices, so each
        # occurrence of a token in the window subtracts ``penalty`` again.
        return logits.at[:, recent].subtract(penalty)

    return frequency_penalty_processor
41 changes: 35 additions & 6 deletions mlx_lm/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,10 @@ class LogitsProcessorArguments:
logit_bias: Optional[Dict[int, float]]
repetition_penalty: float
repetition_context_size: int
presence_penalty: float
presence_context_size: int
frequency_penalty: float
frequency_context_size: int


@dataclass
Expand Down Expand Up @@ -623,6 +627,10 @@ def _make_logits_processors(args):
args.logits.logit_bias,
args.logits.repetition_penalty,
args.logits.repetition_context_size,
args.logits.presence_penalty,
args.logits.presence_context_size,
args.logits.frequency_penalty,
args.logits.frequency_context_size,
)


Expand Down Expand Up @@ -1208,6 +1216,10 @@ def do_POST(self):
self.min_p = self.body.get("min_p", self.response_generator.cli_args.min_p)
self.repetition_penalty = self.body.get("repetition_penalty", 0.0)
self.repetition_context_size = self.body.get("repetition_context_size", 20)
self.presence_penalty = self.body.get("presence_penalty", 0.0)
self.presence_context_size = self.body.get("presence_context_size", 20)
self.frequency_penalty = self.body.get("frequency_penalty", 0.0)
self.frequency_context_size = self.body.get("frequency_context_size", 20)
self.xtc_probability = self.body.get("xtc_probability", 0.0)
self.xtc_threshold = self.body.get("xtc_threshold", 0.0)
self.logit_bias = self.body.get("logit_bias", None)
Expand Down Expand Up @@ -1256,6 +1268,25 @@ def validate_model_parameters(self):
or self.repetition_penalty < 0
):
raise ValueError("repetition_penalty must be a non-negative float")
if (
not isinstance(self.repetition_context_size, int)
or self.repetition_context_size < 0
):
raise ValueError("repetition_context_size must be a non-negative integer")
if not isinstance(self.presence_penalty, (float, int)):
raise ValueError("Presence penalty must be must be a float")
if (
not isinstance(self.presence_context_size, int)
or self.presence_context_size < 0
):
raise ValueError("presence_context_size must be a non-negative integer")
if not isinstance(self.frequency_penalty, (float, int)):
raise ValueError("Presence penalty must be must be a float")
if (
not isinstance(self.frequency_context_size, int)
or self.frequency_context_size < 0
):
raise ValueError("frequency_context_size must be a non-negative integer")

if not isinstance(self.logprobs, bool):
raise ValueError("logprobs must be a boolean")
Expand All @@ -1265,12 +1296,6 @@ def validate_model_parameters(self):
f"top_logprobs must be between 1 and 10 but got {self.top_logprobs:,}"
)

if (
not isinstance(self.repetition_context_size, int)
or self.repetition_context_size < 0
):
raise ValueError("repetition_context_size must be a non-negative integer")

if self.logit_bias is not None:
if not isinstance(self.logit_bias, dict):
raise ValueError("logit_bias must be a dict of int to float")
Expand Down Expand Up @@ -1430,6 +1455,10 @@ def handle_completion(self, request: CompletionRequest, stop_words: List[str]):
logit_bias=self.logit_bias,
repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size,
presence_penalty=self.presence_penalty,
presence_context_size=self.presence_context_size,
frequency_penalty=self.frequency_penalty,
frequency_context_size=self.frequency_context_size,
),
stop_words=stop_words,
max_tokens=self.max_tokens,
Expand Down
58 changes: 58 additions & 0 deletions tests/test_sample_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,64 @@ def test_apply_xtc(self):
new_probs = mx.softmax(apply_xtc(mx.log(probs), 0, 0.1, [0]), -1)
self.assertTrue(mx.allclose(new_probs, probs))

def test_presence_penalty(self):
    from mlx_lm.sample_utils import make_presence_penalty

    processor = make_presence_penalty(0.5, context_size=5)
    history = mx.array([0, 0, 0, 1, 1])
    penalized = processor(history, mx.zeros((1, 4)))
    # Presence is binary: tokens 0 and 1 each lose 0.5 exactly once,
    # regardless of how many times they occur in the context.
    self.assertAlmostEqual(penalized[0, 0].item(), -0.5)
    self.assertAlmostEqual(penalized[0, 1].item(), -0.5)
    # Tokens absent from the context are left untouched.
    self.assertAlmostEqual(penalized[0, 2].item(), 0.0)
    self.assertAlmostEqual(penalized[0, 3].item(), 0.0)

def test_frequency_penalty(self):
    from mlx_lm.sample_utils import make_frequency_penalty

    processor = make_frequency_penalty(0.5, context_size=5)
    history = mx.array([0, 0, 0, 1, 1])
    penalized = processor(history, mx.zeros((1, 4)))
    # Frequency scales with occurrence count:
    #   token 0 appears 3x -> 3 * 0.5 = 1.5 subtracted
    #   token 1 appears 2x -> 2 * 0.5 = 1.0 subtracted
    self.assertAlmostEqual(penalized[0, 0].item(), -1.5)
    self.assertAlmostEqual(penalized[0, 1].item(), -1.0)
    # Tokens absent from the context are left untouched.
    self.assertAlmostEqual(penalized[0, 2].item(), 0.0)
    self.assertAlmostEqual(penalized[0, 3].item(), 0.0)

def test_make_logits_processors(self):
    from mlx_lm.sample_utils import make_logits_processors

    history = mx.array([0, 0, 0, 1, 1])
    # Start from non-zero logits so the multiplicative repetition penalty
    # has a visible effect.
    logits = mx.array([[1.0, 0.5, 0.0, -0.5]])
    processors = make_logits_processors(
        repetition_penalty=1.5,
        repetition_context_size=5,
        presence_penalty=0.5,
        presence_context_size=5,
        frequency_penalty=0.25,
        frequency_context_size=5,
    )
    for apply_penalty in processors:
        logits = apply_penalty(history, logits)
    # Expected, applying repetition, presence, then frequency penalties:
    #   token 0 (3 occurrences): 1.0 / 1.5 - 0.5 - 3 * 0.25 = -0.5833
    #   token 1 (2 occurrences): 0.5 / 1.5 - 0.5 - 2 * 0.25 = -0.6667
    #   token 2 (absent): unchanged at 0.0
    #   token 3 (absent): unchanged at -0.5
    self.assertAlmostEqual(logits[0, 0].item(), -0.5833, places=4)
    self.assertAlmostEqual(logits[0, 1].item(), -0.6667, places=4)
    self.assertAlmostEqual(logits[0, 2].item(), 0.0, places=4)
    self.assertAlmostEqual(logits[0, 3].item(), -0.5, places=4)


if __name__ == "__main__":
unittest.main()
Loading