Merge pull request #2584 from ming1753/internet
support return_all_tokens & stop_seqs
Jiang-Jia-Jun authored Jan 8, 2025
2 parents 608d4be + c249b98 commit 97e541e
Showing 4 changed files with 104 additions and 1 deletion.
45 changes: 44 additions & 1 deletion llm/server/server/data/processor.py
@@ -143,6 +143,9 @@ def process_request(self, request, max_seq_len=None):
request["eos_token_ids"] = []
request["eos_token_ids"].extend(get_eos_token_id(self.tokenizer, self.config.generation_config))

if "stop_seqs" not in request or (isinstance(request["stop_seqs"], (list, tuple)) and len(request["stop_seqs"]) == 0):
self.update_stop_seq(request)

if "input_ids" not in request or \
(isinstance(request["input_ids"], (list, tuple)) and len(request["input_ids"]) == 0):
if "text" in request:
@@ -282,7 +285,7 @@ def _load_tokenizer(self):
"""
if self.config.use_hf_tokenizer:
from transformers import AutoTokenizer
return AutoTokenizer.from_pretrained(self.config.model_dir, use_fast=False, vocab_file=os.path.join(self.config.model_dir, "sentencepiece.bpe.model"))
return AutoTokenizer.from_pretrained(self.config.model_dir, use_fast=False)
else:
from paddlenlp.transformers import AutoTokenizer
return AutoTokenizer.from_pretrained(self.config.model_dir)
@@ -334,3 +337,43 @@ def get_pad_id(self):
        if isinstance(self.tokenizer, (LlamaTokenizer, Llama3Tokenizer)) and not self.tokenizer.pad_token_id:
            return self.tokenizer.eos_token
        return self.tokenizer.pad_token_id

    def pad_batch_data(self, insts, pad_id=0, return_seq_len=False, return_array=True, pad_style="right"):
        """Pad the instances to the max sequence length in batch."""
        if len(insts) == 0:
            padded_insts = np.array([[]], dtype=np.int64) if return_array else [[]]
            if return_seq_len:
                seq_len = np.array([], dtype=np.int64) if return_array else []
                return padded_insts, seq_len
            return padded_insts

        max_len = max(map(len, insts))
        if pad_style == "left":
            padded_insts = [[pad_id] * (max_len - len(inst)) + list(inst) for inst in insts]
        else:
            padded_insts = [list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts]
        if return_array:
            padded_insts = np.array(padded_insts, dtype=np.int64).reshape([-1, max_len])

        if return_seq_len:
            seq_len = [len(inst) for inst in insts]
            if return_array:
                seq_len = np.array(seq_len, dtype=np.int64).reshape(-1, 1)
            return padded_insts, seq_len
        return padded_insts

    def update_stop_seq(self, request):
        """
        Update stop sequences from request.
        """
        stop_seqs = []
        for seq in request.get("stop_sequences", []):
            if seq != self.tokenizer.eos_token_id:
                stop_seqs.append(self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(seq)))
        request["stop_seqs"], request["stop_seqs_len"] = self.pad_batch_data(
            stop_seqs,
            pad_id=-1,
            return_seq_len=True,
            return_array=False
        )
        data_processor_logger.debug(f"processed request: {request['stop_seqs'], request['stop_seqs_len']}")
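
For reference, update_stop_seq turns the request's plain-text stop_sequences into token-id lists and right-pads them with pad_id=-1 via pad_batch_data, also recording each sequence's original length. A minimal sketch of that right-padding step, with made-up token ids (real ids come from tokenizer.tokenize / convert_tokens_to_ids):

def pad_right(insts, pad_id=-1):
    # Re-statement of the right-padding branch of pad_batch_data, list output only.
    max_len = max(map(len, insts))
    padded = [list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts]
    seq_len = [len(inst) for inst in insts]
    return padded, seq_len

stop_seqs = [[1022, 18], [907]]          # hypothetical token ids for two stop strings
padded, seq_len = pad_right(stop_seqs)
print(padded)   # [[1022, 18], [907, -1]]
print(seq_len)  # [2, 1]
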
35 changes: 35 additions & 0 deletions llm/server/server/engine/infer.py
@@ -52,6 +52,11 @@ def __init__(self, args):
        self.args.num_attention_heads = self.get_value(self.model_cfg, ["num_attention_heads", "n_head"])
        self.args.hidden_size = self.model_cfg["hidden_size"]

        self.reduce_dialogue_repetition = int(os.environ.get("REDUCE_DIALOGUE_REPETITION", 0))

        self.max_stop_seqs_num = int(os.getenv("MAX_STOP_SEQS_NUM", 5))
        self.stop_seqs_max_len = int(os.getenv("STOP_SEQS_MAX_LEN", 8))

        self.nranks = dist.get_world_size()
        self.init_dist_env()
        self.rank = fleet.worker_index()
@@ -246,6 +251,19 @@ def init_inputs(self):
        self.share_inputs['free_list_len'] = paddle.full(
            shape=[1], fill_value=self.free_list_len, dtype="int32")

        self.share_inputs['stop_seqs_len'] = paddle.full(shape=[self.max_stop_seqs_num,],
                                                         fill_value=0,
                                                         dtype="int32")
        self.share_inputs['stop_seqs'] = paddle.full(shape=[self.max_stop_seqs_num, self.stop_seqs_max_len],
                                                     fill_value=-1,
                                                     dtype="int64")

        if self.reduce_dialogue_repetition:
            self.share_inputs["first_token_ids"] = paddle.full(
                shape=[self.args.max_batch_size, 1], fill_value=-1, dtype="int64")
            self.share_inputs["ori_seq_lens_encoder"] = paddle.full(
                shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")

    def dy_input_preprocess(self, tasks):
        """
        dynamic insertion
@@ -279,6 +297,10 @@ def dy_input_preprocess(self, tasks):
            self.share_inputs['max_length'][idx:idx + 1] = max_dec_len
            self.share_inputs['stop_flags'][idx:idx + 1] = False

            if self.reduce_dialogue_repetition:
                self.share_inputs['first_token_ids'][idx:idx + 1] = self.share_inputs['input_ids'][idx:idx + 1, :1]
                self.share_inputs["ori_seq_lens_encoder"][idx:idx + 1] = length

            if "infer_seed" in task:
                self.share_inputs['infer_seed'][idx:idx + 1] = task['infer_seed']

@@ -288,6 +310,14 @@ def dy_input_preprocess(self, tasks):
self.share_inputs["block_tables"][idx:idx + 1, :encoder_block_num] = np.array(
task['block_tables'], dtype="int32")

if "stop_seqs_len" in task:
stop_seqs_num = len(task["stop_seqs_len"])
for i in range(stop_seqs_num, self.max_stop_seqs_num):
task["stop_seqs_len"].append(0)
self.share_inputs['stop_seqs_len'][:] = np.array(
task["stop_seqs_len"], dtype="int32")
self.share_inputs['stop_seqs'][:stop_seqs_num, :len(task['stop_seqs'][0])] = np.array(
task["stop_seqs"], dtype="int64")
def step_cuda(self, seq_lens_this_time):
"""
step cuda
@@ -474,6 +504,11 @@ def _init_predictor(self):
        config.switch_ir_optim(False)
        config.enable_use_gpu(100, device_id)

        pir_flag = int(os.environ.get("FLAGS_enable_pir_api", 0))
        if pir_flag == 1:
            config.enable_new_executor()
            config.enable_new_ir()

        # distributed config
        if self.mp_degree > 1:
            trainer_endpoints = fleet.worker_endpoints()
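
The stop-sequence buffers above are sized by two environment variables read in __init__ (MAX_STOP_SEQS_NUM, default 5; STOP_SEQS_MAX_LEN, default 8), and dy_input_preprocess pads a task's stop_seqs_len out to that fixed size before copying it into the shared inputs. A minimal numpy-only sketch of that step, with made-up task values standing in for the Paddle tensors:

import numpy as np

max_stop_seqs_num = 5    # MAX_STOP_SEQS_NUM default
stop_seqs_max_len = 8    # STOP_SEQS_MAX_LEN default

# Hypothetical fields as produced by the data processor above.
task = {"stop_seqs": [[1022, 18], [907, -1]], "stop_seqs_len": [2, 1]}

# Pad stop_seqs_len up to the fixed buffer size, as dy_input_preprocess does.
stop_seqs_num = len(task["stop_seqs_len"])
for _ in range(stop_seqs_num, max_stop_seqs_num):
    task["stop_seqs_len"].append(0)

stop_seqs_len = np.array(task["stop_seqs_len"], dtype="int32")            # [2 1 0 0 0]
stop_seqs = np.full((max_stop_seqs_num, stop_seqs_max_len), -1, dtype="int64")
stop_seqs[:stop_seqs_num, :len(task["stop_seqs"][0])] = np.array(task["stop_seqs"], dtype="int64")
# Rows 0-1 now hold the two stop sequences; the rest of the buffer stays -1.
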
1 change: 1 addition & 0 deletions llm/server/server/http_server/api.py
@@ -31,6 +31,7 @@ class Req(BaseModel):
    req_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    input_ids: Optional[List[int]] = None
    text: Optional[str] = None
    stop_sequences: Optional[List] = None
    messages: Optional[List] = None
    max_dec_len: Optional[int] = None
    seq_len: Optional[int] = None
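
With the new field, a request body can carry plain-text stop sequences alongside the existing fields; the data processor tokenizes them into stop_seqs / stop_seqs_len as shown in processor.py above. A hypothetical payload (only the field names mirror the Req model; the values, and how the request is sent, are illustrative):

payload = {
    "req_id": "demo-0001",                    # optional; defaults to a fresh uuid4
    "text": "Write a haiku about the sea.",
    "max_dec_len": 128,
    "stop_sequences": ["\n\n", "###"],        # illustrative stop strings, tokenized server-side
}
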
24 changes: 24 additions & 0 deletions llm/server/server/triton_server.py
@@ -98,11 +98,35 @@ def _push_mode_sender_thread(self):
            except Exception as e:
                model_server_logger.error("Unexcepted error happend: {}, {}".format(e, str(traceback.format_exc())))

    def _cache_special_tokens(self, batch_result):
        for i in range(len(batch_result)):
            is_end = batch_result[i].get("is_end", 0)
            token_ids = batch_result[i]["token_ids"]
            if is_end != 1:
                if batch_result[i]["req_id"] not in self.token_buffer:
                    self.token_buffer[batch_result[i]["req_id"]] = list()
                    self.score_buffer[batch_result[i]["req_id"]] = list()
                self.token_buffer[batch_result[i]["req_id"]].extend(token_ids)
                self.score_buffer[batch_result[i]["req_id"]].extend(batch_result[i].get("token_scores", []))
                batch_result[i]["token_ids"] = []
                if "token_scores" in batch_result[i]:
                    batch_result[i]["token_scores"] = []
            else:
                if batch_result[i]["req_id"] in self.token_buffer:
                    batch_result[i]["token_ids"] = self.token_buffer[batch_result[i]["req_id"]] + batch_result[i]["token_ids"]
                    del self.token_buffer[batch_result[i]["req_id"]]
                    if "token_scores" in batch_result[i]:
                        batch_result[i]["token_scores"] = self.score_buffer[batch_result[i]["req_id"]] + batch_result[i]["token_scores"]
                        del self.score_buffer[batch_result[i]["req_id"]]

    def postprocess(self, batch_result, exist_finished_task=False):
        """
        single postprocess for triton
        """
        try:
            self._cache_special_tokens(batch_result)
            self.cached_generated_tokens.put(batch_result)
        except Exception as e:
            model_server_logger.info(
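
To make the buffering behaviour concrete, here is a minimal standalone sketch of the same idea as _cache_special_tokens: tokens (and scores) of unfinished chunks are held in per-request buffers and released together with the final is_end == 1 chunk. The host class, req ids and token ids below are made up; the real method runs on the Triton server object:

class _BufferDemo:
    """Hypothetical stand-in for the server object; only the two buffers matter here."""
    def __init__(self):
        self.token_buffer = {}
        self.score_buffer = {}

    def cache_special_tokens(self, batch_result):
        # Hold back tokens of unfinished requests; prepend them to the final chunk.
        for res in batch_result:
            req_id = res["req_id"]
            if res.get("is_end", 0) != 1:
                self.token_buffer.setdefault(req_id, []).extend(res["token_ids"])
                self.score_buffer.setdefault(req_id, []).extend(res.get("token_scores", []))
                res["token_ids"] = []
                if "token_scores" in res:
                    res["token_scores"] = []
            elif req_id in self.token_buffer:
                res["token_ids"] = self.token_buffer.pop(req_id) + res["token_ids"]
                if "token_scores" in res:
                    res["token_scores"] = self.score_buffer.pop(req_id) + res["token_scores"]

demo = _BufferDemo()
chunk1 = [{"req_id": "r1", "is_end": 0, "token_ids": [11, 12]}]
chunk2 = [{"req_id": "r1", "is_end": 1, "token_ids": [13]}]
demo.cache_special_tokens(chunk1)   # chunk1[0]["token_ids"] -> [] (held back)
demo.cache_special_tokens(chunk2)   # chunk2[0]["token_ids"] -> [11, 12, 13]
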
