Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
From00 committed Feb 13, 2025
1 parent bbbeb62 commit 382ff07
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 38 deletions.
10 changes: 6 additions & 4 deletions llm/auto_parallel/llama/run_llama2_13b_xpu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ unset PADDLE_TRAINERS_NUM
#export BKCL_RDMA_NICS=xgbe1,xgbe1,xgbe2,xgbe2,xgbe3,xgbe3,xgbe4,xgbe4
#export BKCL_SOCKET_IFNAME=xgbe0
#export BKCL_FORCE_L3_RDMA=0
export LD_LIBRARY_PATH=/usr/local/lib:/usr/lib64
echo "bkcl version:"
strings ${bkcl_location}/libbkcl.so | grep COM

Expand All @@ -52,8 +53,8 @@ export CUDA_DEVICE_MAX_CONNECTIONS=8
export PYTHONPATH=../../../:$PYTHONPATH

# for debug
#export GLOG_v=6
#export FLAGS_call_stack_level=2
#export GLOG_v=10
export FLAGS_call_stack_level=2

rm -rf output/$task_name_or_path
PYTHONPATH=../:$PYTHONPATH \
Expand Down Expand Up @@ -92,7 +93,7 @@ python -u -m paddle.distributed.launch \
--dataloader_num_workers 4 \
--pipeline_parallel_degree 1 \
--tensor_parallel_degree 1 \
--gradient_accumulation_steps 32 \
--gradient_accumulation_steps 1 \
--eval_steps 1000 \
--report_to "visualdl" \
--disable_tqdm true \
Expand All @@ -101,4 +102,5 @@ python -u -m paddle.distributed.launch \
--do_train \
--seed 1026 \
--device "xpu" \
--enable_auto_parallel 1
--enable_auto_parallel 1 \
--to_static 1
5 changes: 0 additions & 5 deletions paddlenlp/trainer/auto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from tqdm.auto import tqdm

from paddlenlp.trainer import Trainer
from paddlenlp.utils.tools import get_env_device

from ..utils.batch_sampler import DistributedBatchSampler as NlpDistributedBatchSampler
from ..utils.log import logger
Expand Down Expand Up @@ -523,10 +522,6 @@ def _inner_training_loop(

logger.info("\nTraining completed. \n")

# Hack for XPU that doesn't support Allgather yet. See LlamaPretrainingCriterion3DAuto in modeling_auto.py for details.
if get_env_device() == "xpu":
tr_loss = tr_loss.mean()

self._total_loss_scalar += self._get_item_from_loss(tr_loss)
train_loss = self._total_loss_scalar / self.state.global_step

Expand Down
50 changes: 21 additions & 29 deletions paddlenlp/transformers/llama/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,10 @@ def scaled_dot_product_attention(
return (attn_output, attn_weights) if output_attentions else attn_output


# Placement lists defined once at module level.
# NOTE(review): presumably one entry per process-mesh dimension, i.e. replicated on the
# first dimension and sharded along tensor axis 1 (column-wise) or axis 0 (row-wise) on
# the second -- confirm against the mesh the layers in this module are built on.
colwise_placements = [dist.Replicate(), dist.Shard(1)]
rowise_placement = [dist.Replicate(), dist.Shard(0)]


class LlamaRMSNormAuto(nn.Layer):
def __init__(self, config, ipp):
super().__init__()
Expand Down Expand Up @@ -237,16 +241,6 @@ def __init__(self, config, ipp: Optional[int] = None):
self.fuse_attention_ffn = config.fuse_attention_ffn
self.ipp = ipp
self.config = config
colwise_placements = (
[dist.Replicate(), dist.Shard(1)]
if self.config.tensor_parallel_degree > 1
else [dist.Replicate(), dist.Replicate()]
)
rowise_placement = (
[dist.Replicate(), dist.Shard(0)]
if self.config.tensor_parallel_degree > 1
else [dist.Replicate(), dist.Replicate()]
)

if config.fuse_attention_ffn and not enable_fuse_ffn_qkv_pass():
self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
Expand Down Expand Up @@ -316,17 +310,6 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
self.recompute_granularity = config.recompute_granularity
self.ipp = ipp

colwise_placements = (
[dist.Replicate(), dist.Shard(1)]
if self.config.tensor_parallel_degree > 1
else [dist.Replicate(), dist.Replicate()]
)
rowise_placement = (
[dist.Replicate(), dist.Shard(0)]
if self.config.tensor_parallel_degree > 1
else [dist.Replicate(), dist.Replicate()]
)

self.use_fused_rope = config.use_fused_rope
if self.use_fused_rope and get_env_device() not in ["npu", "mlu", "xpu", "gcu", "intel_hpu"]:
if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
Expand Down Expand Up @@ -1201,10 +1184,23 @@ def forward(self, prediction_scores, masked_lm_labels):
masked_lm_labels.unsqueeze(2),
)

# Hack for XPU that doesn't support Allgather yet.
# XPU does not support allgather of a mask with bool dtype, so we use LocalLayer here.
if get_env_device() == "xpu":
# masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32")
loss = paddle.mean(masked_lm_loss, axis=-1)

class LocalLossLayer(paddle.distributed.LocalLayer):
    """Reduce a per-token loss to its masked mean inside a ``LocalLayer``.

    The surrounding code uses this on XPU, where the accompanying comment says
    allgather of a bool-dtype mask is unsupported.
    NOTE(review): the local-tensor execution semantics come from
    ``paddle.distributed.LocalLayer`` itself -- confirm against its documentation.
    """

    def __init__(self, out_dist_attrs):
        # Hand the output distribution attributes straight to the base class.
        super().__init__(out_dist_attrs)

    def forward(self, x, mask):
        """Return the mean of the elements of ``x`` selected by ``mask``, as float32."""
        selected = paddle.masked_select(x, mask)
        return paddle.mean(selected.astype("float32"))

out_dist_attrs = [
(masked_lm_loss.process_mesh, [dist.Partial(dist.ReduceType.kRedSum), dist.Replicate()]),
]
loss_func = LocalLossLayer(out_dist_attrs)
loss = loss_func(masked_lm_loss, masked_lm_loss > 0)
else:
masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32")
loss = paddle.mean(masked_lm_loss)
Expand All @@ -1216,11 +1212,7 @@ class LlamaLMHeadAuto(nn.Layer):
def __init__(self, config: LlamaConfig):
super(LlamaLMHeadAuto, self).__init__()
self.config = config
colwise_placements = (
[dist.Replicate(), dist.Shard(1)]
if self.config.tensor_parallel_degree > 1
else [dist.Replicate(), dist.Replicate()]
)

vocab_size = config.vocab_size
self.weight = self.create_parameter(
shape=[config.hidden_size, vocab_size],
Expand Down

0 comments on commit 382ff07

Please sign in to comment.