diff --git a/examples/model_configs/vllm_model_config.yaml b/examples/model_configs/vllm_model_config.yaml
index be8941a66..5192cb558 100644
--- a/examples/model_configs/vllm_model_config.yaml
+++ b/examples/model_configs/vllm_model_config.yaml
@@ -10,5 +10,5 @@ model:
     top_k: -1
     min_p: 0.0
     top_p: 0.9
-    max_new_tokens: 100
+    max_new_tokens: 256
     stop_tokens: ["", ""]
diff --git a/src/lighteval/models/vllm/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py
index 1742e6f30..105b3a22a 100644
--- a/src/lighteval/models/vllm/vllm_model.py
+++ b/src/lighteval/models/vllm/vllm_model.py
@@ -316,7 +316,9 @@ def _generate(
         sampling_params = self.sampling_params.clone() or SamplingParams()
         if generate:
             sampling_params.n = num_samples
-            sampling_params.max_tokens = max_new_tokens
+            sampling_params.max_tokens = (
+                max_new_tokens if sampling_params.max_tokens is None else sampling_params.max_tokens
+            )
             sampling_params.stop = stop_tokens
             sampling_params.logprobs = 1 if returns_logits else 0
 
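
The second hunk changes how _generate resolves the generation length: a max_tokens value already carried by the cloned SamplingParams (for example, one populated from the generation block of the YAML config above) now takes precedence, and the task-supplied max_new_tokens is only used as a fallback when the config leaves it unset. Below is a minimal, self-contained sketch of that precedence rule; StubSamplingParams and resolve_max_tokens are hypothetical stand-ins for illustration, not part of lighteval or vLLM.

from dataclasses import dataclass
from typing import Optional


@dataclass
class StubSamplingParams:
    # Stand-in for vllm.SamplingParams, reduced to the one field the patch touches.
    max_tokens: Optional[int] = None


def resolve_max_tokens(sampling_params: StubSamplingParams, max_new_tokens: Optional[int]) -> Optional[int]:
    # Same precedence as the patched line: keep a config-provided max_tokens,
    # otherwise fall back to the task-level max_new_tokens.
    return max_new_tokens if sampling_params.max_tokens is None else sampling_params.max_tokens


# Config sets a limit (e.g. max_new_tokens: 256 in the YAML) -> the config value wins.
assert resolve_max_tokens(StubSamplingParams(max_tokens=256), max_new_tokens=100) == 256

# Config leaves the limit unset -> the task-level default is used.
assert resolve_max_tokens(StubSamplingParams(max_tokens=None), max_new_tokens=100) == 100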