104 changes: 99 additions & 5 deletions paddleformers/data/causal_dataset.py
@@ -13,6 +13,8 @@
from .indexed_dataset import make_dataset as make_indexed_dataset

local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))
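# Address of the external inference server queried by get_logits() below; override
# via the INFER_SERVER_IP / INFER_SERVER_PORT environment variables.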
INFER_SERVER_IP = os.getenv("INFER_SERVER_IP", "127.0.0.1")
INFER_SERVER_PORT = os.getenv("INFER_SERVER_PORT", "8008")


# class FakeHCG:
@@ -25,6 +27,57 @@
# def get_model_parallel_group(self):
# return None

import pickle

import requests


def get_logits(batch_ids, max_retries=1, timeout=1200, retry_delay=1, prob_nums=10):
"""
Retrieve logits with retry mechanism if no response is received within the specified time

Parameters:
batch_ids: Input token ids
max_retries: Maximum number of retry attempts (default: 1)
timeout: Request timeout in seconds (default: 1200 seconds)
retry_delay: Delay between retries in seconds (default: 1 second)
prob_nums: Number of probabilities to return

Returns:
tuple: (logits, ids)
Raises:
Exception: Thrown when all retry attempts fail
"""

headers = {"Content_Type": "application/json"}
url = f"http://{INFER_SERVER_IP}:{INFER_SERVER_PORT}/generate"
payload = {
"prompt_token_ids": batch_ids,
"max_tokens": 1,
"top_p": 1,
"top_k": -1,
"temperature": 1,
"prompt_logprobs": prob_nums,
"logprobs": prob_nums,
}

for attempt in range(max_retries):
try:
response = requests.post(url=url, json=payload, headers=headers, timeout=timeout)
            response.raise_for_status()  # check for HTTP errors

data = pickle.loads(response.content)
all_token = data.get("logits", [])
all_ids = data.get("ids", [])
all_token = paddle.to_tensor(all_token, dtype="bfloat16")
all_ids = paddle.to_tensor(all_ids, dtype="int64")
return all_token, all_ids

except (requests.exceptions.RequestException, IOError) as e:
if attempt == max_retries - 1:
raise Exception(f"Failed after {max_retries} attempts. Last error: {str(e)}")
time.sleep(retry_delay)
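
# Illustrative usage sketch (not part of this change): how get_logits() is expected
# to be called for one packed sample. The token ids below are placeholders and the
# exact tensor shapes depend on the inference server's response format.
#
#     sample_ids = [[1, 15, 27, 2]]                     # hypothetical token ids
#     kl_logits, kl_ids = get_logits(sample_ids, prob_nums=10)
#     # kl_logits -> paddle bfloat16 tensor of top-`prob_nums` log-probabilities
#     # kl_ids    -> paddle int64 tensor of the matching candidate token ids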


def check_data_split(splits_string, do_train, do_eval, do_predict):
splits = []
@@ -125,6 +178,8 @@ def build_train_valid_test_datasets(
*,
data_cache_path=None,
need_data=True,
self_constraint_cpt=False,
prob_nums=10,
):
"""Build train, valid, and test datasets."""

@@ -141,6 +196,8 @@
share_folder=share_folder,
data_cache_path=data_cache_path,
need_data=need_data,
self_constraint_cpt=self_constraint_cpt,
prob_nums=prob_nums,
)

# Blending dataset.
@@ -168,6 +225,8 @@ def build_train_valid_test_datasets(
share_folder=share_folder,
data_cache_path=data_cache_path,
need_data=need_data,
self_constraint_cpt=self_constraint_cpt,
prob_nums=prob_nums,
)
if train_ds:
train_datasets.append(train_ds)
@@ -212,8 +271,14 @@ def _build_train_valid_test_datasets(
*,
data_cache_path=None,
need_data=True,
self_constraint_cpt=False,
prob_nums=10,
):
"""Build train, valid, and test datasets."""
CPT = False
if data_prefix.endswith("::CPT"):
data_prefix = data_prefix[: -len("::CPT")]
CPT = True

# Indexed dataset.
if need_data:
@@ -254,6 +319,9 @@ def build_dataset(index, name):
share_folder,
data_cache_path=data_cache_path,
need_data=need_data,
CPT=CPT,
self_constraint_cpt=self_constraint_cpt,
prob_nums=prob_nums,
)
if need_data:
return dataset if splits[index + 1] > splits[index] else None
@@ -295,11 +363,17 @@ def __init__(
*,
data_cache_path=None,
need_data=True,
CPT=False,
self_constraint_cpt=False,
prob_nums=10,
):

self.name = name
self.indexed_dataset = indexed_dataset
self.return_doc_ids = return_doc_ids
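        # New per-dataset flags: whether this corpus is a CPT corpus, whether to fetch
        # reference logits from the inference server (self-constraint CPT), and how many
        # log-probabilities to request per position.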
self.CPT = CPT
self.self_constraint_cpt = self_constraint_cpt
self.prob_nums = prob_nums

# Build index mappings.
if need_data and len(documents) > 0:
@@ -397,21 +471,41 @@ def __getitem__(self, idx):
sample = np.concatenate(sample_list)
if append_mask:
mask = np.concatenate(mask_list)

kl_logits, kl_ids = None, None
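        # Self-constraint CPT: fetch the reference distribution for this sample from the
        # external inference server; the extra "logits"/"ids" fields are added to the
        # returned dict for downstream use (e.g. a KL-style regularization loss).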
if self.self_constraint_cpt:
kl_logits, kl_ids = get_logits([sample.tolist()], prob_nums=self.prob_nums)

res = None
if self.return_doc_ids: # for retro preprocessing
if mask is None:
return {"text": np.array(sample, dtype=np.int64), "doc_ids": np.array(doc_ids, dtype=np.int64)}
res = {
"text": np.array(sample, dtype=np.int64),
"doc_ids": np.array(doc_ids, dtype=np.int64),
}
else:
return {
res = {
"text": np.array(sample, dtype=np.int64),
"doc_ids": np.array(doc_ids, dtype=np.int64),
"mask": np.array(mask, dtype=np.int64),
}
else:
if mask is None:
return {"text": np.array(sample, dtype=np.int64)}
res = {"text": np.array(sample, dtype=np.int64)}
else:
return {"text": np.array(sample, dtype=np.int64), "mask": np.array(mask, dtype=np.int64)}
res = {
"text": np.array(sample, dtype=np.int64),
"mask": np.array(mask, dtype=np.int64),
}

        if self.self_constraint_cpt:
            res.update(
                {
                    "logits": kl_logits,
                    "ids": kl_ids,
                    "CPT": self.CPT,
                }
            )

return res


def _build_index_mappings(