diff --git a/mlx_lm/SERVER.md b/mlx_lm/SERVER.md index 37fc3da7e..f38ad3dd4 100644 --- a/mlx_lm/SERVER.md +++ b/mlx_lm/SERVER.md @@ -72,12 +72,24 @@ curl localhost:8080/v1/chat/completions \ - `min_p`: (Optional) A float specifying the min-p sampling parameter. Defaults to `0.0` (disabled). -- `repetition_penalty`: (Optional) Applies a penalty to repeated tokens. - Defaults to `1.0`. +- `repetition_penalty`: (Optional) Applies a multiplicative penalty to repeated + tokens. Defaults to `0.0` (disabled). - `repetition_context_size`: (Optional) The size of the context window for applying repetition penalty. Defaults to `20`. +- `presence_penalty`: (Optional) Applies an additive penalty to tokens + that appeared before. Defaults to `0.0` (disabled). + +- `presence_context_size`: (Optional) The size of the context window for + applying presence penalty. Defaults to `20`. + +- `frequency_penalty`: (Optional) Applies an additive penalty proportional to + how many times a token appeared previously. Defaults to `0.0` (disabled). + +- `frequency_context_size`: (Optional) The size of the context window for + applying frequency penalty. Defaults to `20`. + - `logit_bias`: (Optional) A dictionary mapping token IDs to their bias values. Defaults to `None`. diff --git a/mlx_lm/sample_utils.py b/mlx_lm/sample_utils.py index db371d3b4..e56242763 100644 --- a/mlx_lm/sample_utils.py +++ b/mlx_lm/sample_utils.py @@ -73,15 +73,28 @@ def make_logits_processors( logit_bias: Optional[Dict[int, float]] = None, repetition_penalty: Optional[float] = None, repetition_context_size: Optional[int] = 20, + presence_penalty: Optional[float] = None, + presence_context_size: Optional[int] = 20, + frequency_penalty: Optional[float] = None, + frequency_context_size: Optional[int] = 20, ): """ Make logits processors for use with ``generate_step``. Args: - repetition_penalty (float, optional): The penalty factor for repeating - tokens. 
+ repetition_penalty (float, optional): A (sign-aware) multiplicative + penalty for repeating tokens. repetition_context_size (int, optional): The number of tokens to consider for repetition penalty. Default: ``20``. + presence_penalty (float, optional): An additive penalty to reduce + repeating tokens. + presence_context_size (int, optional): The number of tokens to consider + for the presence penalty. Default: ``20``. + frequency_penalty (float, optional): An additive penalty to reduce + repeating tokens. The tokens are penalized proportionally to their + frequency. + frequency_context_size (int, optional): The number of tokens to consider + for the frequency penalty. Default: ``20``. logit_bias (dictionary, optional): Additive logit bias. Returns: @@ -96,15 +109,20 @@ def make_logits_processors( values = mx.array(list(logit_bias.values())) def logit_bias_processor(_, logits): - logits[:, indices] += values - return logits + return logits.at[:, indices].add(values) logits_processors.append(logit_bias_processor) - if repetition_penalty and repetition_penalty != 0.0: - logits_processors.append( - make_repetition_penalty(repetition_penalty, repetition_context_size) - ) + repetition_penalties = [ + (make_repetition_penalty, repetition_penalty, repetition_context_size), + (make_presence_penalty, presence_penalty, presence_context_size), + (make_frequency_penalty, frequency_penalty, frequency_context_size), + ] + + for make_penalty, penalty, context_size in repetition_penalties: + if penalty is not None and penalty != 0: + logits_processors.append(make_penalty(penalty, context_size)) + return logits_processors @@ -307,3 +325,58 @@ def repetition_penalty_processor(tokens, logits): return logits return repetition_penalty_processor + + +def make_presence_penalty(penalty: float, context_size: int = 20): + """ + Make a presence penalty processor. + + Corresponds to the OpenAI option with the same name. 
Namely, subtracts + ``penalty`` from a logit if the token has occurred at least once in the + ``context_size`` previous tokens. + + Args: + penalty (float): The presence penalty to be applied. + context_size (int): The number of previous tokens to use. + Default: ``20``. + + Returns: + Callable[[mx.array, List[int]], mx.array] + """ + + def presence_penalty_processor(tokens, logits): + if len(tokens) > 0: + tokens = tokens[-context_size:] + logits[:, tokens] -= penalty + return logits + + return presence_penalty_processor + + +def make_frequency_penalty(penalty: float, context_size: int = 20): + """ + Make a frequency penalty processor. + + Corresponds to the OpenAI option with the same name. Namely, subtracts + ``penalty`` from a logit for every time that the token has occurred in the + ``context_size`` previous tokens. + + The difference with the presence penalty is that the more often a token + occurs the more it will be penalized. + + Args: + penalty (float): The frequency penalty to be applied. + context_size (int): The number of previous tokens to use. + Default: ``20``. 
+ + Returns: + Callable[[mx.array, List[int]], mx.array] + """ + + def frequency_penalty_processor(tokens, logits): + if len(tokens) > 0: + tokens = tokens[-context_size:] + logits = logits.at[:, tokens].subtract(penalty) + return logits + + return frequency_penalty_processor diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 628863ee2..58b369964 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -395,6 +395,10 @@ class LogitsProcessorArguments: logit_bias: Optional[Dict[int, float]] repetition_penalty: float repetition_context_size: int + presence_penalty: float + presence_context_size: int + frequency_penalty: float + frequency_context_size: int @dataclass @@ -623,6 +627,10 @@ def _make_logits_processors(args): args.logits.logit_bias, args.logits.repetition_penalty, args.logits.repetition_context_size, + args.logits.presence_penalty, + args.logits.presence_context_size, + args.logits.frequency_penalty, + args.logits.frequency_context_size, ) @@ -1208,6 +1216,10 @@ def do_POST(self): self.min_p = self.body.get("min_p", self.response_generator.cli_args.min_p) self.repetition_penalty = self.body.get("repetition_penalty", 0.0) self.repetition_context_size = self.body.get("repetition_context_size", 20) + self.presence_penalty = self.body.get("presence_penalty", 0.0) + self.presence_context_size = self.body.get("presence_context_size", 20) + self.frequency_penalty = self.body.get("frequency_penalty", 0.0) + self.frequency_context_size = self.body.get("frequency_context_size", 20) self.xtc_probability = self.body.get("xtc_probability", 0.0) self.xtc_threshold = self.body.get("xtc_threshold", 0.0) self.logit_bias = self.body.get("logit_bias", None) @@ -1256,6 +1268,25 @@ def validate_model_parameters(self): or self.repetition_penalty < 0 ): raise ValueError("repetition_penalty must be a non-negative float") + if ( + not isinstance(self.repetition_context_size, int) + or self.repetition_context_size < 0 + ): + raise ValueError("repetition_context_size must be a 
non-negative integer") + if not isinstance(self.presence_penalty, (float, int)): + raise ValueError("presence_penalty must be a float") + if ( + not isinstance(self.presence_context_size, int) + or self.presence_context_size < 0 + ): + raise ValueError("presence_context_size must be a non-negative integer") + if not isinstance(self.frequency_penalty, (float, int)): + raise ValueError("frequency_penalty must be a float") + if ( + not isinstance(self.frequency_context_size, int) + or self.frequency_context_size < 0 + ): + raise ValueError("frequency_context_size must be a non-negative integer") if not isinstance(self.logprobs, bool): raise ValueError("logprobs must be a boolean") @@ -1265,12 +1296,6 @@ def validate_model_parameters(self): f"top_logprobs must be between 1 and 10 but got {self.top_logprobs:,}" ) - if ( - not isinstance(self.repetition_context_size, int) - or self.repetition_context_size < 0 - ): - raise ValueError("repetition_context_size must be a non-negative integer") - if self.logit_bias is not None: if not isinstance(self.logit_bias, dict): raise ValueError("logit_bias must be a dict of int to float") @@ -1430,6 +1455,10 @@ def handle_completion(self, request: CompletionRequest, stop_words: List[str]): logit_bias=self.logit_bias, repetition_penalty=self.repetition_penalty, repetition_context_size=self.repetition_context_size, + presence_penalty=self.presence_penalty, + presence_context_size=self.presence_context_size, + frequency_penalty=self.frequency_penalty, + frequency_context_size=self.frequency_context_size, ), stop_words=stop_words, max_tokens=self.max_tokens, diff --git a/tests/test_sample_utils.py b/tests/test_sample_utils.py index f587c84af..e4b465401 100644 --- a/tests/test_sample_utils.py +++ b/tests/test_sample_utils.py @@ -116,6 +116,64 @@ def test_apply_xtc(self): new_probs = mx.softmax(apply_xtc(mx.log(probs), 0, 0.1, [0]), -1) self.assertTrue(mx.allclose(new_probs, probs)) + def test_presence_penalty(self): + from 
mlx_lm.sample_utils import make_presence_penalty + + # Token appears multiple times - penalty applied once + tokens = mx.array([0, 0, 0, 1, 1]) + logits = mx.zeros((1, 4)) + processor = make_presence_penalty(0.5, context_size=5) + result = processor(tokens, logits) + # Token 0 appears 3 times, token 1 appears 2 times - both penalized once + self.assertAlmostEqual(result[0, 0].item(), -0.5) + self.assertAlmostEqual(result[0, 1].item(), -0.5) + # Tokens not in context not penalized + self.assertAlmostEqual(result[0, 2].item(), 0.0) + self.assertAlmostEqual(result[0, 3].item(), 0.0) + + def test_frequency_penalty(self): + from mlx_lm.sample_utils import make_frequency_penalty + + # Token appears multiple times - penalty applied proportionally + tokens = mx.array([0, 0, 0, 1, 1]) + logits = mx.zeros((1, 4)) + processor = make_frequency_penalty(0.5, context_size=5) + result = processor(tokens, logits) + # Token 0 appears 3 times -> 3 * 0.5 = 1.5 penalty + self.assertAlmostEqual(result[0, 0].item(), -1.5) + # Token 1 appears 2 times -> 2 * 0.5 = 1.0 penalty + self.assertAlmostEqual(result[0, 1].item(), -1.0) + # Tokens not in context not penalized + self.assertAlmostEqual(result[0, 2].item(), 0.0) + self.assertAlmostEqual(result[0, 3].item(), 0.0) + + def test_make_logits_processors(self): + from mlx_lm.sample_utils import make_logits_processors + + # Create processors with all three penalty types + tokens = mx.array([0, 0, 0, 1, 1]) + # Use non-zero logits so repetition penalty has effect + logits = mx.array([[1.0, 0.5, 0.0, -0.5]]) + processors = make_logits_processors( + repetition_penalty=1.5, + repetition_context_size=5, + presence_penalty=0.5, + presence_context_size=5, + frequency_penalty=0.25, + frequency_context_size=5, + ) + # Apply all processors + for processor in processors: + logits = processor(tokens, logits) + # Token 0 (appears 3x): 1.0/1.5 - 0.5 - 0.75 = -0.5833 + # Token 1 (appears 2x): 0.5/1.5 - 0.5 - 0.5 = -0.6667 + # Token 2 (not in context): 0.0 
(no penalty) + # Token 3 (not in context): -0.5 (no penalty) + self.assertAlmostEqual(logits[0, 0].item(), -0.5833, places=4) + self.assertAlmostEqual(logits[0, 1].item(), -0.6667, places=4) + self.assertAlmostEqual(logits[0, 2].item(), 0.0, places=4) + self.assertAlmostEqual(logits[0, 3].item(), -0.5, places=4) + if __name__ == "__main__": unittest.main()