Cherry-pick bugfix PRs #9089

Merged · 13 commits · Sep 5, 2024
7 changes: 5 additions & 2 deletions paddlenlp/data/causal_dataset.py
@@ -94,10 +94,11 @@ def get_datasets_weights_and_num_samples(data_prefix, train_val_test_num_samples
# Add 0.5% (the 1.005 factor) so in case the bleding dataset does
# not uniformly distribute the number of samples, we still have
# samples left to feed to the network.
+ # (NOTE, yujun06): This is a workaround to avoid issues with indexing in the blending dataset. Therefore, we need to add 20 samples to each dataset.
datasets_train_valid_test_num_samples = []
for weight in weights:
datasets_train_valid_test_num_samples.append(
- [int(math.ceil(val * weight * 1.005)) for val in train_val_test_num_samples]
+ [int(math.ceil(val * weight * 1.005)) + 20 for val in train_val_test_num_samples]
)

return prefixes, weights, datasets_train_valid_test_num_samples
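
A minimal, self-contained sketch of the sampling math after this change may help when skimming the diff; the function name and the weights/split sizes below are made up for illustration, not taken from the PR.

    import math

    def per_dataset_num_samples(weights, train_val_test_num_samples):
        # Scale the requested train/valid/test sizes by each dataset's weight,
        # add 0.5% slack for uneven blending, then pad every split by 20 samples
        # as a guard against out-of-range indexing in the blended dataset.
        return [
            [int(math.ceil(n * w * 1.005)) + 20 for n in train_val_test_num_samples]
            for w in weights
        ]

    # Two datasets weighted 0.7 / 0.3, requesting 1000/100/10 samples overall.
    print(per_dataset_num_samples([0.7, 0.3], [1000, 100, 10]))
    # [[724, 91, 28], [322, 51, 24]]
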
@@ -146,7 +147,9 @@ def build_train_valid_test_datasets(
# Parse the values.
output = get_datasets_weights_and_num_samples(data_prefix, train_val_test_num_samples)
prefixes, weights, datasets_train_valid_test_num_samples = output
- train_num_samples, valid_num_samples, test_num_samples = map(sum, zip(*datasets_train_valid_test_num_samples))
+ # NOTE: megatron/gpt_dataset.py has been updated. When creating BlendableDataset, we will use the raw train_val_test_num_samples instead of the expanded ones.
+ # Please refer to https://github.com/NVIDIA/NeMo/blob/72f630d087d45655b1a069dc72debf01dfdbdb2d/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py#L74-L80 for more information
+ train_num_samples, valid_num_samples, test_num_samples = train_val_test_num_samples

# Build individual datasets.
train_datasets = []
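
Continuing the illustrative numbers from the sketch above: summing the padded per-dataset counts overshoots what was actually requested, which is why the raw split sizes are now handed to the blended dataset instead.

    # Padded per-dataset counts from the previous sketch (illustrative only).
    padded = [[724, 91, 28], [322, 51, 24]]
    print([sum(col) for col in zip(*padded)])  # [1046, 142, 52] -- the old blended totals
    print([1000, 100, 10])                     # the raw totals used from now on
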
31 changes: 15 additions & 16 deletions paddlenlp/trainer/trainer.py
@@ -1051,11 +1051,12 @@ def _inner_training_loop(
if optimizer_was_run:
self.lr_scheduler.step()

- if enable_release_grads and args.pipeline_parallel_degree > 1:
+ if args.release_grads or enable_release_grads:
self.optimizer.clear_grad(set_to_zero=False)
- for _, buffers in model._chunk_2_comm_buffers.items():
- for buffer in buffers:
- buffer._clear_grad_storage()
+ if args.pipeline_parallel_degree > 1:
+ for _, buffers in model._chunk_2_comm_buffers.items():
+ for buffer in buffers:
+ buffer._clear_grad_storage()
else:
self.optimizer.clear_grad()
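
Restated outside the diff view, the new clearing logic looks roughly like the sketch below; `clear_gradients` is a wrapper written only for illustration, while the attribute names are the ones appearing in the diff.

    def clear_gradients(optimizer, model, release_grads, pipeline_parallel_degree):
        # With the new `release_grads` argument (or `enable_release_grads` in the
        # pipeline config), gradient storage is released instead of zeroed,
        # regardless of the parallel mode.
        if release_grads:
            optimizer.clear_grad(set_to_zero=False)
            if pipeline_parallel_degree > 1:
                # Pipeline-parallel models also keep gradients in communication
                # buffers, whose storage is dropped explicitly.
                for buffers in model._chunk_2_comm_buffers.values():
                    for buffer in buffers:
                        buffer._clear_grad_storage()
        else:
            optimizer.clear_grad()
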

@@ -1434,6 +1435,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
eval_dataset,
batch_size=self.args.per_device_eval_batch_size,
collate_fn=self.data_collator,
+ drop_last=self.args.dataloader_drop_last,
num_workers=0,
eval=True,
)
@@ -1442,6 +1444,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
eval_dataset,
batch_size=self.args.per_device_eval_batch_size,
collate_fn=self.data_collator,
+ drop_last=self.args.dataloader_drop_last,
num_workers=0,
)

@@ -1454,15 +1457,15 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
eval_dataset,
batch_sampler=eval_sampler,
collate_fn=self.data_collator,
- num_workers=self.args.dataloader_num_workers,
+ num_workers=0,
eval=True,
)
else:
return DataLoader(
eval_dataset,
batch_sampler=eval_sampler,
collate_fn=self.data_collator,
- num_workers=self.args.dataloader_num_workers,
+ num_workers=0,
)

def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
@@ -1500,13 +1503,15 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
batch_size=self.args.per_device_eval_batch_size * self.world_size,
collate_fn=self.data_collator, # _get_collator_with_removed_columns
num_workers=0,
+ drop_last=self.args.dataloader_drop_last,
eval=True,
)
else:
return DataLoader(
test_dataset,
batch_size=self.args.per_device_eval_batch_size * self.world_size,
collate_fn=self.data_collator, # _get_collator_with_removed_columns
+ drop_last=self.args.dataloader_drop_last,
num_workers=0,
)

@@ -1520,6 +1525,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
test_dataset,
batch_sampler=test_sampler,
collate_fn=self.data_collator,
+ num_workers=0,
drop_last=self.args.dataloader_drop_last,
eval=True,
)
@@ -1529,6 +1535,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
batch_sampler=test_sampler,
collate_fn=self.data_collator,
drop_last=self.args.dataloader_drop_last,
+ num_workers=0,
)

def create_optimizer_and_scheduler(self, num_training_steps: int):
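
Taken together, the dataloader edits pin evaluation and test loading to a single worker process and route `dataloader_drop_last` through to the loader. Below is a reduced sketch of the batch_size-based branch, assuming `paddle.io.DataLoader`; the helper name is illustrative.

    from paddle.io import DataLoader

    def build_eval_or_test_loader(args, dataset, data_collator):
        # num_workers is fixed at 0 (single-process loading), and drop_last now
        # follows the --dataloader_drop_last argument rather than silently
        # keeping a final short batch.
        return DataLoader(
            dataset,
            batch_size=args.per_device_eval_batch_size,
            collate_fn=data_collator,
            num_workers=0,
            drop_last=args.dataloader_drop_last,
        )
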
@@ -1748,16 +1755,8 @@ def _wrap_model(self, model, training=True):
in_sep_parallel_mode = self.args.sep_parallel_degree > 1

# Multi-gpu training
- if (
- self.args.world_size > 1
- and not self.args.use_hybrid_parallel
- or not (
- in_pipeline_parallel_mode
- or in_sharding_parallel_mode
- or in_tensor_parallel_mode
- or in_sep_parallel_mode
- )
- ):
+ if self.args.world_size > 1 and (not self.args.use_hybrid_parallel):
+ # MOE use DDP to broadcaset parameters.
model = paddle.DataParallel(model)
# Distributed training (should be after fp16 initialization)

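
One way to read the removed condition: `and` binds tighter than `or`, so it parsed as "(world_size > 1 and not hybrid) or not (any parallel mode)", which could wrap even a plain single-card run in DataParallel. A small truth-value illustration with hypothetical flag values:

    # Hypothetical single-process run with no parallel modes enabled.
    world_size_gt_1 = False
    use_hybrid_parallel = False
    in_pp = in_sharding = in_tp = in_sep = False

    old_condition = (
        world_size_gt_1
        and not use_hybrid_parallel
        or not (in_pp or in_sharding or in_tp or in_sep)
    )
    new_condition = world_size_gt_1 and (not use_hybrid_parallel)
    print(old_condition, new_condition)  # True False
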
9 changes: 7 additions & 2 deletions paddlenlp/trainer/training_args.py
@@ -344,6 +344,8 @@ class TrainingArguments:
Whether skip profile timer, timer will record time usage of forward/ backward/ step, etc.
distributed_dataloader (`bool`, *optional*):
Whether to use distributed dataloader. Default is `False`.
+ release_grads (`bool`, *optional*):
+ Whether to release gradients during training. Default is `False`.
"""

output_dir: str = field(
@@ -787,6 +789,9 @@ class TrainingArguments:
default=False,
metadata={"help": "whether to run distributed training in auto parallel mode"},
)
+ release_grads: Optional[bool] = field(
+ default=False, metadata={"help": "Whether to release gradients during training. Default is `False`."}
+ )

def __post_init__(self):
env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
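
A minimal usage sketch of the new argument; the output directory is a placeholder, and everything else follows the fields added in this diff.

    from paddlenlp.trainer import TrainingArguments

    # release_grads can now be enabled directly, instead of only through
    # `--pipeline_parallel_config enable_release_grads`.
    args = TrainingArguments(output_dir="./checkpoints", release_grads=True)
    print(args.release_grads)  # True
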
@@ -1030,7 +1035,7 @@ def __post_init__(self):
"dp_comm_overlap": enable_dp_comm_overlap,
"sharding_comm_overlap": enable_sharding_comm_overlap,
"enable_timer": "enable_timer" in pipeline_parallel_config,
"release_gradients": "enable_release_grads" in pipeline_parallel_config,
"release_gradients": "enable_release_grads" in pipeline_parallel_config or self.release_grads,
"overlap_p2p_comm": "enable_overlap_p2p_comm" in pipeline_parallel_config,
"clear_every_step_cache": "enable_clear_every_step_cache" in pipeline_parallel_config,
"use_batch_p2p_comm": "disable_batch_p2p_comm" not in pipeline_parallel_config,
@@ -1400,7 +1405,7 @@ def is_segment_parallel_supported():
if world_size > 1:
if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized():
if self.unified_checkpoint:
- self.use_hybrid_parallel = True
+ # DP use hybrid group
strategy = fleet.DistributedStrategy()
fleet.init(is_collective=True, strategy=strategy)
else:
6 changes: 3 additions & 3 deletions paddlenlp/transformers/model_utils.py
@@ -798,7 +798,7 @@ def _load_state_dict_into_meta_model(

dtype = convert_np_dtype_to_dtype_(dtype)
error_msgs = []

+ model_state_dict = model.state_dict()
for param_name, param in state_dict.items():
# First part of the test is always true as loaded_state_dict_keys always contains state_dict keys.
if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
@@ -833,7 +833,7 @@ def _load_state_dict_into_meta_model(
if old_param is not None:
param = param.astype(dtype=old_param.dtype)
with paddle.no_grad():
- model.state_dict()[param_name].get_tensor()._share_data_with(param.value().get_tensor())
+ model_state_dict[param_name].get_tensor()._share_data_with(param.value().get_tensor())
param.value().get_tensor()._clear()
return error_msgs

@@ -1890,7 +1890,7 @@ def _find_mismatched_keys(
if (
shard_file.endswith(".safetensors")
and config.tensor_parallel_degree > 1
and "tp" not in shard_file
and "tp" not in os.path.split(shard_file)[-1]
):
pre_tensor_parallel_split = True
assert loaded_keys is not None, "loaded_keys is not None."
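
The basename check avoids a false positive when the checkpoint directory, rather than the shard file itself, happens to contain the substring "tp"; the path below is made up for illustration.

    import os

    shard_file = "/work/output/model-00001-of-00008.safetensors"
    print("tp" in shard_file)                     # True  ("output" contains "tp")
    print("tp" in os.path.split(shard_file)[-1])  # False (only the file name is checked now)
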
6 changes: 2 additions & 4 deletions paddlenlp/utils/batch_sampler.py
@@ -14,8 +14,6 @@

from __future__ import division, print_function

- import math

import paddle

__all__ = ["DistributedBatchSampler"]
@@ -110,7 +108,7 @@ def __init__(
# In pre-training mode when using distributed dataloader, the input dataset can be None. We should handle this situation.
self.num_samples = 0
else:
- self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
+ self.num_samples = int(len(self.dataset) * 1.0 / self.nranks)
self.total_size = self.num_samples * self.nranks

def get_start_end_idx(self):
@@ -125,7 +123,7 @@ def __iter__(self):
self.consumed_samples,
self.nranks,
)
- self.remain_num_samples = int(math.ceil((len(self.dataset) - self.consumed_samples) * 1.0 / self.nranks))
+ self.remain_num_samples = int((len(self.dataset) - self.consumed_samples) * 1.0 / self.nranks)
self.remain_total_size = self.remain_num_samples * self.nranks
self.batch_size_times_rank_size = self.batch_size * self.nranks
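
Dropping math.ceil changes how the remainder is handled when the dataset size is not divisible by the number of ranks; a quick comparison with made-up sizes:

    import math

    dataset_len, nranks = 1002, 8
    old = int(math.ceil(dataset_len * 1.0 / nranks))  # 126 -> 126 * 8 = 1008, pads past the dataset
    new = int(dataset_len * 1.0 / nranks)             # 125 -> 125 * 8 = 1000, drops the last 2 samples
    print(old, new)
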

12 changes: 6 additions & 6 deletions paddlenlp/utils/safetensors.py
@@ -157,16 +157,16 @@ def __getitem__(self, index):

out_start, out_stop, out_step = copy.deepcopy((self.start, self.stop, self.step))
for i, (start, stop, step, slice_) in enumerate(zip(self.start, self.stop, self.step, index)):
- out_start[i] = slice_.start or 0
- out_step[i] = slice_.step or 1
- out_stop[i] = slice_.stop or stop - start
+ out_start[i] = slice_.start if slice_.start is not None else 0
+ out_step[i] = slice_.step if slice_.step is not None else 1
+ out_stop[i] = slice_.stop if slice_.stop is not None else stop - start
out_stop[i] = min(stop, out_stop[i])

target_shape = []
- for x, y, z in zip(out_start, out_stop, out_step):
+ for x, y, z, sli in zip(out_start, out_stop, out_step, index):
assert z == 1, "only support step = 1"
- if y - x > 1:
- target_shape.append(int(y - x))
+ if y - x > 1 or sli.step is None:
+ target_shape.append(max(int(y - x), 0))

if len(target_shape) == 0:
if self.shape == [1]:
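
Two details of this fix are easy to miss: `x or default` falls back even when an explicit 0 is passed, and dimensions whose slice leaves the step unspecified are now kept in the target shape even when their extent is 0 or 1 (clamped at 0). The first point in isolation:

    # An intentionally empty selection along a dimension of extent 16.
    sli = slice(0, 0)
    dim_extent = 16

    old_stop = sli.stop or dim_extent                            # 16 -- 0 is falsy, wrong fallback
    new_stop = sli.stop if sli.stop is not None else dim_extent  # 0  -- only None triggers the default
    print(old_stop, new_stop)  # 16 0
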
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -10,7 +10,7 @@ exclude = ['.flake8']

[tool.pytest.ini_options]
minversion = "6.0"
addopts = "-ra -q --ignore model_zoo/gpt-3/"
addopts = "-ra -q --dist loadgroup"
pythonpath = ["."]
testpaths = [
"tests/data",
@@ -28,7 +28,7 @@ testpaths = [
"tests/prompt",
# "tests/taskflow", TODO (paddle 2.5.1 breaks this test suite, debug later)
"tests/utils",
"model_zoo",
# "model_zoo",
]
python_files = [
"test.py",