support length penalty (#358)
ANDgate99 authored Mar 18, 2024
1 parent f1f66fe commit b445a6b
Showing 4 changed files with 29 additions and 9 deletions.
4 changes: 3 additions & 1 deletion lightllm/server/router/model_infer/infer_batch.py
@@ -4,7 +4,7 @@
import collections

from dataclasses import dataclass, field
-from typing import List, Dict
+from typing import List, Dict, Tuple
from lightllm.common.req_manager import ReqManager
from lightllm.common.mem_manager import MemoryManager
from lightllm.utils.infer_utils import mark_start, mark_end
@@ -25,6 +25,7 @@ def __init__(
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
repetition_penalty: float = 1.0,
+exponential_decay_length_penalty: Tuple[int, float] = (1, 1.0),
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
@@ -34,6 +35,7 @@
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.repetition_penalty = repetition_penalty
+self.exponential_decay_length_penalty = exponential_decay_length_penalty
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
7 changes: 4 additions & 3 deletions lightllm/server/router/model_infer/model_rpc.py
@@ -100,6 +100,7 @@ def exposed_init_model(self, kvargs):

try:
self.model_type = model_cfg.get("model_type", "")
+self.eos_id = model_cfg.get("eos_token_id", 2)
if self.model_type == "bloom":
self.model = BloomTpPartModel(model_kvargs)
elif self.model_type == "llama":
@@ -279,7 +280,7 @@ def forward(self, batch_id, is_prefill):
kwargs, run_reqs = prepare_decode_inputs(batch, self.radix_cache, self.model.mem_manager)

logits = self.model.forward(**kwargs)
-next_token_ids, next_token_probs = sample(logits, run_reqs)
+next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
next_token_ids = next_token_ids.detach().cpu().numpy()
next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()

@@ -317,7 +318,7 @@ def _prefill_to_return_all_prompt_logprobs(self, batch_id):
last_index = torch.cumsum(b_seq_len, dim=0, dtype=torch.long) - 1
logits = prompt_all_logits[last_index, :]

-next_token_ids, next_token_probs = sample(logits, run_reqs)
+next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
next_token_ids = next_token_ids.detach().cpu().numpy()
next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()

@@ -370,7 +371,7 @@ def splitfuse_forward(self, batch_id):
all_reqs.extend(prefill_reqs)

logits = self.model.splitfuse_forward(**kwargs)
-next_token_ids, next_token_probs = sample(logits, all_reqs)
+next_token_ids, next_token_probs = sample(logits, all_reqs, self.eos_id)
next_token_ids = next_token_ids.detach().cpu().numpy()
next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()

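The only functional change in this file is that sample() now receives the EOS token id, read from the model config with a fallback of 2. A minimal sketch of where that value comes from, assuming a HuggingFace-style config.json in the weight directory (the helper name and path handling are illustrative, not part of the commit):

```python
import json
import os

def load_eos_id(weight_dir: str) -> int:
    # Mirrors model_cfg.get("eos_token_id", 2) above: fall back to token id 2
    # (the usual LLaMA EOS id) when the config does not declare one explicitly.
    with open(os.path.join(weight_dir, "config.json")) as f:
        model_cfg = json.load(f)
    return model_cfg.get("eos_token_id", 2)
```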
16 changes: 12 additions & 4 deletions lightllm/server/router/model_infer/post_process.py
@@ -1,14 +1,15 @@
import re
import torch
-from typing import List
+from typing import List, Tuple
from lightllm.server.router.model_infer.infer_batch import InferBatch
from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty

-def sample(logits, reqs):
+def sample(logits, reqs, eos_id=2):
logits = logits.contiguous()
-presence_penalties, frequency_penalties, repetition_penalties, temperatures, top_ps, top_ks, p_token_ids, p_token_counts, p_cumsum_seq_len, p_max_len_in_batch = _get_post_sample_tensors(reqs)
+presence_penalties, frequency_penalties, repetition_penalties, exponential_decay_length_penalties, temperatures, top_ps, top_ks, p_token_ids, p_token_counts, p_cumsum_seq_len, p_max_len_in_batch, length_penalty_idx = _get_post_sample_tensors(reqs)

apply_penalty(logits, presence_penalties, frequency_penalties, repetition_penalties, p_token_ids, p_token_counts, p_cumsum_seq_len, p_max_len_in_batch)
+logits[:, eos_id] = logits[:, eos_id] + torch.abs(logits[:, eos_id]) * (torch.pow(exponential_decay_length_penalties, length_penalty_idx).view((-1, 1)) - 1)
logits.div_(temperatures.view((-1, 1)))
probs = torch.softmax(logits, dim=-1)
probs_sort, probs_idx = _top_p_top_k(probs, top_ps, top_ks)
@@ -33,19 +34,24 @@ def _get_post_sample_tensors(reqs):
presence_penalties: List[float] = []
frequency_penalties: List[float] = []
repetition_penalties: List[float] = []
+exponential_decay_length_penalties: List[float] = []
temperatures: List[float] = []
top_ps: List[float] = []
top_ks: List[int] = []
p_token_ids: List[int] = []
p_token_counts: List[int] = []
p_seq_len: List[int] = [0,]
p_max_len_in_batch: int = 0
+length_penalty_idx: List[int] = []
for i, req_obj in enumerate(reqs):
id_to_count = req_obj.out_token_id_count
sample_param = req_obj.sampling_param
presence_penalties.append(sample_param.presence_penalty)
frequency_penalties.append(sample_param.frequency_penalty)
repetition_penalties.append(sample_param.repetition_penalty)
+exponential_decay_length_penalties.append(sample_param.exponential_decay_length_penalty[1])
+length_penalty_idx.append(max(len(req_obj.input_token_ids) - req_obj.prompt_len - sample_param.exponential_decay_length_penalty[0], 0))
+
temperatures.append(sample_param.temperature)
top_ps.append(sample_param.top_p)
top_ks.append(sample_param.top_k)
@@ -59,11 +65,13 @@
presence_penalties = torch.tensor(presence_penalties, dtype=torch.float, device="cuda")
frequency_penalties = torch.tensor(frequency_penalties, dtype=torch.float, device="cuda")
repetition_penalties = torch.tensor(repetition_penalties, dtype=torch.float, device="cuda")
+exponential_decay_length_penalties = torch.tensor(exponential_decay_length_penalties, dtype=torch.float, device="cuda")
temperatures = torch.tensor(temperatures, dtype=torch.float, device="cuda")
top_ps = torch.tensor(top_ps, dtype=torch.float, device="cuda")
top_ks = torch.tensor(top_ks, dtype=torch.int32, device="cuda")
p_token_ids = torch.tensor(p_token_ids, dtype=torch.int32, device="cuda")
p_token_counts = torch.tensor(p_token_counts, dtype=torch.int32, device="cuda")
p_seq_len = torch.tensor(p_seq_len, dtype=torch.int32, device="cuda")
p_cumsum_seq_len = torch.cumsum(p_seq_len, dim=0, dtype=torch.int32)
-return presence_penalties, frequency_penalties, repetition_penalties, temperatures, top_ps, top_ks, p_token_ids, p_token_counts, p_cumsum_seq_len, p_max_len_in_batch
+length_penalty_idx = torch.tensor(length_penalty_idx, dtype=torch.int32, device="cuda")
+return presence_penalties, frequency_penalties, repetition_penalties, exponential_decay_length_penalties, temperatures, top_ps, top_ks, p_token_ids, p_token_counts, p_cumsum_seq_len, p_max_len_in_batch, length_penalty_idx
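The new line in sample() is the core of the feature: once a request has generated more tokens than the configured start index, the EOS logit is raised by abs(logit) * (factor ** overshoot - 1), so stopping becomes progressively more likely. A standalone sketch of that arithmetic (the function name and toy values are illustrative, and a single scalar EOS id is assumed):

```python
import torch

def boost_eos_logit(logits: torch.Tensor, eos_id: int,
                    decay_factors: torch.Tensor, penalty_idx: torch.Tensor) -> torch.Tensor:
    """Per-request EOS boost, as applied in sample() above.

    decay_factors: factor >= 1.0 per request; 1.0 leaves the logit untouched.
    penalty_idx:   max(tokens_generated - start_index, 0) per request.
    """
    boost = torch.pow(decay_factors, penalty_idx) - 1.0  # shape (batch,)
    # abs() keeps the adjustment additive even when the EOS logit is negative.
    logits[:, eos_id] = logits[:, eos_id] + torch.abs(logits[:, eos_id]) * boost
    return logits

# Toy check: factor 1.05, 10 tokens past the start index. An EOS logit of -2.0
# rises by |-2.0| * (1.05 ** 10 - 1) ~= 1.26, i.e. from -2.0 to about -0.74.
logits = torch.tensor([[0.3, -2.0, 1.1]])
print(boost_eos_logit(logits, eos_id=1,
                      decay_factors=torch.tensor([1.05]),
                      penalty_idx=torch.tensor([10.0])))
```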
11 changes: 10 additions & 1 deletion lightllm/server/sampling_params.py
@@ -1,5 +1,5 @@
"""Sampling parameters for text generation."""
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Tuple

_SAMPLING_EPS = 1e-5

@@ -12,6 +12,7 @@ def __init__(
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
repetition_penalty: float = 1.0,
+exponential_decay_length_penalty: Tuple[int, float] = (1, 1.0),
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1, # -1 is for all
@@ -23,6 +24,7 @@ def __init__(
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.repetition_penalty = repetition_penalty
+self.exponential_decay_length_penalty = exponential_decay_length_penalty
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
@@ -53,6 +55,12 @@ def verify(self):
raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.")
if self.max_new_tokens < 1:
raise ValueError(f"max_new_tokens must be at least 1 , got {self.max_new_tokens}.")
+if len(self.exponential_decay_length_penalty) != 2:
+raise ValueError(f"exponential_decay_length_penalty must be a tuple of (int, float), got {self.exponential_decay_length_penalty}.")
+if not isinstance(self.exponential_decay_length_penalty[0], int) or self.exponential_decay_length_penalty[0] < 0:
+raise ValueError(f"exponential_decay_length_penalty[0] must be a non-negative integer, got {self.exponential_decay_length_penalty[0]}.")
+if not isinstance(self.exponential_decay_length_penalty[1], float) or self.exponential_decay_length_penalty[1] < 1.0:
+raise ValueError(f"exponential_decay_length_penalty[1] must be a float >= 1.0, got {self.exponential_decay_length_penalty[1]}.")
return

def stop_sentences_to_token_ids(self, tokenizer):
@@ -77,6 +85,7 @@ def to_dict(self):
ret["presence_penalty"] = self.presence_penalty
ret["frequency_penalty"] = self.frequency_penalty
ret["repetition_penalty"] = self.repetition_penalty
+ret["exponential_decay_length_penalty"] = self.exponential_decay_length_penalty
ret["temperature"] = self.temperature
ret["top_p"] = self.top_p
ret["top_k"] = self.top_k
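From the caller's side the new parameter is a (start_index, decay_factor) pair: the penalty stays dormant for the first start_index generated tokens and then compounds by decay_factor per extra token, with the default (1, 1.0) effectively disabling it. A hedged usage sketch follows; it assumes the remaining SamplingParams arguments keep the defaults shown above, and the import path may differ in your checkout:

```python
from lightllm.server.sampling_params import SamplingParams

# Start nudging the model toward EOS after 64 generated tokens, compounding
# the boost by 1.02 for every token beyond that point.
params = SamplingParams(exponential_decay_length_penalty=(64, 1.02))
params.verify()

# The new verify() checks are strict: the factor must be a float >= 1.0
# (an integer 1 is rejected) and the start index a non-negative int.
for bad in [(64, 1), (-1, 1.5), (64,)]:
    try:
        SamplingParams(exponential_decay_length_penalty=bad).verify()
    except ValueError as err:
        print(err)
```

Since to_dict() now includes the tuple as well, the setting can travel with the request to the model workers, where InferSamplingParams stores it for _get_post_sample_tensors.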
