diff --git a/paddleformers/data/causal_dataset.py b/paddleformers/data/causal_dataset.py
index d3789287649..fc5169877ad 100644
--- a/paddleformers/data/causal_dataset.py
+++ b/paddleformers/data/causal_dataset.py
@@ -13,6 +13,8 @@
 from .indexed_dataset import make_dataset as make_indexed_dataset
 
 local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))
+INFER_SERVER_IP = os.getenv("INFER_SERVER_IP", "127.0.0.1")
+INFER_SERVER_PORT = os.getenv("INFER_SERVER_PORT", "8008")
 
 
 # class FakeHCG:
@@ -25,6 +27,57 @@
 #     def get_model_parallel_group(self):
 #         return None
 
+import pickle
+
+import requests
+
+
+def get_logits(batch_ids, max_retries=1, timeout=1200, retry_delay=1, prob_nums=10):
+    """
+    Retrieve logits from the inference server, retrying if no response is received within the timeout.
+
+    Parameters:
+        batch_ids: Input token ids
+        max_retries: Maximum number of retry attempts (default: 1)
+        timeout: Request timeout in seconds (default: 1200 seconds)
+        retry_delay: Delay between retries in seconds (default: 1 second)
+        prob_nums: Number of probabilities to return
+
+    Returns:
+        tuple: (logits, ids)
+    Raises:
+        Exception: Raised when all retry attempts fail
+    """
+
+    headers = {"Content-Type": "application/json"}
+    url = f"http://{INFER_SERVER_IP}:{INFER_SERVER_PORT}/generate"
+    payload = {
+        "prompt_token_ids": batch_ids,
+        "max_tokens": 1,
+        "top_p": 1,
+        "top_k": -1,
+        "temperature": 1,
+        "prompt_logprobs": prob_nums,
+        "logprobs": prob_nums,
+    }
+
+    for attempt in range(max_retries):
+        try:
+            response = requests.post(url=url, json=payload, headers=headers, timeout=timeout)
+            response.raise_for_status()  # check for HTTP errors
+
+            data = pickle.loads(response.content)
+            all_token = data.get("logits", [])
+            all_ids = data.get("ids", [])
+            all_token = paddle.to_tensor(all_token, dtype="bfloat16")
+            all_ids = paddle.to_tensor(all_ids, dtype="int64")
+            return all_token, all_ids
+
+        except (requests.exceptions.RequestException, IOError) as e:
+            if attempt == max_retries - 1:
+                raise Exception(f"Failed after {max_retries} attempts. Last error: {str(e)}") from e
+            time.sleep(retry_delay)
+
 
 def check_data_split(splits_string, do_train, do_eval, do_predict):
     splits = []
@@ -125,6 +178,8 @@ def build_train_valid_test_datasets(
     *,
     data_cache_path=None,
     need_data=True,
+    self_constraint_cpt=False,
+    prob_nums=10,
 ):
     """Build train, valid, and test datasets."""
 
@@ -141,6 +196,8 @@ def build_train_valid_test_datasets(
             share_folder=share_folder,
             data_cache_path=data_cache_path,
             need_data=need_data,
+            self_constraint_cpt=self_constraint_cpt,
+            prob_nums=prob_nums,
         )
 
     # Blending dataset.
@@ -168,6 +225,8 @@ def build_train_valid_test_datasets(
             share_folder=share_folder,
             data_cache_path=data_cache_path,
             need_data=need_data,
+            self_constraint_cpt=self_constraint_cpt,
+            prob_nums=prob_nums,
         )
         if train_ds:
             train_datasets.append(train_ds)
@@ -212,8 +271,14 @@ def _build_train_valid_test_datasets(
     *,
     data_cache_path=None,
     need_data=True,
+    self_constraint_cpt=False,
+    prob_nums=10,
 ):
     """Build train, valid, and test datasets."""
+    CPT = False
+    if data_prefix.endswith("::CPT"):
+        data_prefix = data_prefix[: -len("::CPT")]
+        CPT = True
 
     # Indexed dataset.
     if need_data:
@@ -254,6 +319,9 @@ def build_dataset(index, name):
             share_folder,
             data_cache_path=data_cache_path,
             need_data=need_data,
+            CPT=CPT,
+            self_constraint_cpt=self_constraint_cpt,
+            prob_nums=prob_nums,
         )
         if need_data:
             return dataset if splits[index + 1] > splits[index] else None
@@ -295,11 +363,17 @@ def __init__(
         *,
         data_cache_path=None,
         need_data=True,
+        CPT=False,
+        self_constraint_cpt=False,
+        prob_nums=10,
     ):
 
         self.name = name
         self.indexed_dataset = indexed_dataset
         self.return_doc_ids = return_doc_ids
+        self.CPT = CPT
+        self.self_constraint_cpt = self_constraint_cpt
+        self.prob_nums = prob_nums
 
         # Build index mappings.
         if need_data and len(documents) > 0:
@@ -397,21 +471,41 @@ def __getitem__(self, idx):
             sample = np.concatenate(sample_list)
             if append_mask:
                 mask = np.concatenate(mask_list)
-        # print(sample)
+
+        kl_logits, kl_ids = None, None
+        if self.self_constraint_cpt:
+            kl_logits, kl_ids = get_logits([sample.tolist()], prob_nums=self.prob_nums)
+
+        res = None
         if self.return_doc_ids:  # for retro preprocessing
             if mask is None:
-                return {"text": np.array(sample, dtype=np.int64), "doc_ids": np.array(doc_ids, dtype=np.int64)}
+                res = {
+                    "text": np.array(sample, dtype=np.int64),
+                    "doc_ids": np.array(doc_ids, dtype=np.int64),
+                }
             else:
-                return {
+                res = {
                     "text": np.array(sample, dtype=np.int64),
                     "doc_ids": np.array(doc_ids, dtype=np.int64),
                     "mask": np.array(mask, dtype=np.int64),
                 }
         else:
             if mask is None:
-                return {"text": np.array(sample, dtype=np.int64)}
+                res = {"text": np.array(sample, dtype=np.int64)}
             else:
-                return {"text": np.array(sample, dtype=np.int64), "mask": np.array(mask, dtype=np.int64)}
+                res = {
+                    "text": np.array(sample, dtype=np.int64),
+                    "mask": np.array(mask, dtype=np.int64),
+                }
+
+        if self.self_constraint_cpt:
+            res.update({
+                "logits": kl_logits,
+                "ids": kl_ids,
+                "CPT": self.CPT,
+            })
+
+        return res
 
 
 def _build_index_mappings(
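The client in this patch only assumes an endpoint at http://{INFER_SERVER_IP}:{INFER_SERVER_PORT}/generate that accepts the JSON payload above and replies with a pickled dict holding "logits" and "ids". Below is a minimal sketch of a hypothetical stand-in server for exercising get_logits() locally; Flask, the zero-valued placeholder arrays, and the (seq_len, prob_nums) shapes are assumptions for illustration, not the behavior of the real inference service.

# Hypothetical stand-in for the inference server queried by get_logits().
# Mimics only the contract visible in the client code: POST /generate with a
# JSON payload, respond with a pickled dict containing "logits" and "ids".
import pickle

import numpy as np
from flask import Flask, Response, request

app = Flask(__name__)


@app.route("/generate", methods=["POST"])
def generate():
    payload = request.get_json(force=True)
    batch_ids = payload["prompt_token_ids"]  # [[token_id, ...], ...]
    prob_nums = payload.get("prompt_logprobs", 10)  # requested top-k size

    # Placeholder top-k log-probs and candidate token ids per input position;
    # the (seq_len, prob_nums) shape is an assumption for illustration only.
    seq_len = len(batch_ids[0])
    logits = np.zeros((seq_len, prob_nums), dtype=np.float32).tolist()
    ids = np.zeros((seq_len, prob_nums), dtype=np.int64).tolist()

    return Response(pickle.dumps({"logits": logits, "ids": ids}), mimetype="application/octet-stream")


if __name__ == "__main__":
    # Defaults match INFER_SERVER_IP / INFER_SERVER_PORT in the patch above.
    app.run(host="127.0.0.1", port=8008)

With this stub running on the default 127.0.0.1:8008, a call such as get_logits([[1, 2, 3]]) returns a bfloat16 paddle tensor of logits and an int64 tensor of candidate ids, matching what GPTDataset.__getitem__ attaches to the sample when self_constraint_cpt is enabled.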