Skip to content

Commit

Permalink
[AutoParallel] unify llama model
Browse files Browse the repository at this point in the history
  • Loading branch information
deepllz committed Mar 15, 2024
1 parent c406d90 commit 8110fad
Showing 1 changed file with 68 additions and 50 deletions.
118 changes: 68 additions & 50 deletions paddlenlp/transformers/llama/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,26 @@ def swiglu(x, y=None):
]


def is_pp_enable():
    """Report whether the global process mesh has a pipeline-parallel axis.

    Returns:
        bool: True when ``"pp"`` is one of the dimension names of the mesh
        returned by ``fleet.auto.get_mesh()``, False otherwise.
    """
    return "pp" in fleet.auto.get_mesh().dim_names

Check warning on line 87 in paddlenlp/transformers/llama/modeling_auto.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_auto.py#L86-L87

Added lines #L86 - L87 were not covered by tests


def get_mesh(pp_idx=0):
    """Return the process mesh to use for a given pipeline stage.

    Args:
        pp_idx: Index passed to ``get_mesh_with_dim`` along the ``"pp"``
            dimension. Only used when the global mesh has a ``"pp"``
            dimension. Defaults to 0.

    Returns:
        The result of ``get_mesh_with_dim("pp", pp_idx)`` on the global mesh
        when it has a ``"pp"`` dimension; otherwise the global mesh itself.
    """
    global_mesh = fleet.auto.get_mesh()
    # Without a pipeline axis there is nothing to select; use the mesh as is.
    if "pp" not in global_mesh.dim_names:
        return global_mesh
    return global_mesh.get_mesh_with_dim("pp", pp_idx)


def global_mesh_starts_with_pp():
    """Return the global mesh, passed through ``get_mesh_with_dim("pp")`` if pipeline parallel is on.

    Returns:
        ``mesh.get_mesh_with_dim("pp")`` when ``is_pp_enable()`` is true,
        otherwise the unmodified mesh from ``fleet.auto.get_mesh()``.

    NOTE(review): judging by the function name, ``get_mesh_with_dim("pp")``
    presumably yields the full mesh with ``"pp"`` as its leading dimension --
    confirm against the Paddle ``ProcessMesh`` documentation.
    """
    # Fetch the mesh first, then check for the pipeline axis, as the callers
    # may rely on this evaluation order.
    whole_mesh = fleet.auto.get_mesh()
    return whole_mesh.get_mesh_with_dim("pp") if is_pp_enable() else whole_mesh

Check warning on line 102 in paddlenlp/transformers/llama/modeling_auto.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_auto.py#L102

Added line #L102 was not covered by tests


def scaled_dot_product_attention(
query_states,
config,
Expand Down Expand Up @@ -800,21 +813,25 @@ def __init__(self, config: LlamaConfig):
[dist.Replicate(), dist.Shard(1)],
)

def get_layer_ipp(layer_index):
def get_layer_pp_info(layer_index):

Check warning on line 816 in paddlenlp/transformers/llama/modeling_auto.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_auto.py#L816

Added line #L816 was not covered by tests
mesh = fleet.auto.get_mesh()
if "pp" not in mesh.dim_names:
return None
if is_pp_enable() is False:
return None, False

Check warning on line 819 in paddlenlp/transformers/llama/modeling_auto.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_auto.py#L818-L819

Added lines #L818 - L819 were not covered by tests
else:
pp_degree = mesh.get_dim_size("pp")
layer_per_stage = math.ceil(config.num_hidden_layers / pp_degree)
return layer_index // layer_per_stage

self.layers = nn.LayerList(
[
LlamaDecoderLayerAuto(config, i not in self.no_recompute_layers, get_layer_ipp(i))
for i in range(config.num_hidden_layers)
]
)
input_need_reshard = layer_index % layer_per_stage == 0
return layer_index // layer_per_stage, input_need_reshard

Check warning on line 824 in paddlenlp/transformers/llama/modeling_auto.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_auto.py#L823-L824

Added lines #L823 - L824 were not covered by tests

decoder_layers = []
self.next_pp_stage_indexes = []
for i in range(config.num_hidden_layers):
pp_stage_id, input_need_reshard = get_layer_pp_info(i)
decoder_layers.append(LlamaDecoderLayerAuto(config, False, pp_stage_id))
if input_need_reshard:
self.next_pp_stage_indexes.append(i)

Check warning on line 832 in paddlenlp/transformers/llama/modeling_auto.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_auto.py#L826-L832

Added lines #L826 - L832 were not covered by tests

self.layers = nn.LayerList(decoder_layers)

Check warning on line 834 in paddlenlp/transformers/llama/modeling_auto.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_auto.py#L834

Added line #L834 was not covered by tests
self.norm = LlamaRMSNormAuto(config)

self.gradient_checkpointing = False
Expand All @@ -840,13 +857,6 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
combined_attention_mask = _make_causal_mask(
input_shape, past_key_values_length=past_key_values_length
)
# NOTE(zhaoyingli): infer spmd does not support [seq_len, seq_len] --> [batch, 1, seq_len, seq_len] in data_parallel
combined_attention_mask = dist.shard_tensor(
combined_attention_mask,
get_mesh(),
[dist.Replicate(), dist.Replicate()],
)

expanded_attn_mask = expanded_attn_mask & combined_attention_mask
# [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
elif len(attention_mask.shape) == 3:
Expand Down Expand Up @@ -903,6 +913,20 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

if self.config.sequence_parallel:

Check warning on line 916 in paddlenlp/transformers/llama/modeling_auto.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_auto.py#L916

Added line #L916 was not covered by tests
# [B, S, H] -> [S, B, H]
inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])

Check warning on line 918 in paddlenlp/transformers/llama/modeling_auto.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_auto.py#L918

Added line #L918 was not covered by tests

global_mesh = global_mesh_starts_with_pp()
if position_ids is None:
position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))

Check warning on line 922 in paddlenlp/transformers/llama/modeling_auto.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_auto.py#L920-L922

Added lines #L920 - L922 were not covered by tests

position_ids = dist.shard_tensor(

Check warning on line 924 in paddlenlp/transformers/llama/modeling_auto.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_auto.py#L924

Added line #L924 was not covered by tests
position_ids,
global_mesh,
[dist.Replicate() for _ in range(len(global_mesh._shape))],
)

# embed positions
if attention_mask is None:
# [bs, seq_len]
Expand All @@ -914,22 +938,18 @@ def forward(
else:
alibi = None

if position_ids is None:
position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
# NOTE(zhaoyingli): infer spmd does not support [seq_len] --> [batch, seq_len] in data_parallel
position_ids = dist.shard_tensor(position_ids, get_mesh(), [dist.Replicate(), dist.Replicate()])

if self.config.sequence_parallel:
# [B, S, H] -> [S, B, H]
inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])

if self.config.use_flash_attention:
# attention_mask in flash_attn is always None for pretrain
attention_mask = None
else:
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
) # [bs, 1, seq_len, seq_len]
attention_mask = dist.shard_tensor(

Check warning on line 948 in paddlenlp/transformers/llama/modeling_auto.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_auto.py#L948

Added line #L948 was not covered by tests
attention_mask,
global_mesh,
[dist.Replicate() for _ in range(len(global_mesh._shape))],
)

hidden_states = inputs_embeds
hidden_states = dist.reshard(hidden_states, get_mesh(), self.placements)
Expand All @@ -939,34 +959,34 @@ def forward(
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None

pre_ipp = None
for idx, (decoder_layer) in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None

has_gradient = not hidden_states.stop_gradient
ipp = decoder_layer.ipp
if not is_pp_enable():
position_ids_input = position_ids
attention_mask_input = attention_mask

Check warning on line 971 in paddlenlp/transformers/llama/modeling_auto.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_auto.py#L968-L971

Added lines #L968 - L971 were not covered by tests
else:
position_ids_input = dist.reshard(

Check warning on line 973 in paddlenlp/transformers/llama/modeling_auto.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_auto.py#L973

Added line #L973 was not covered by tests
position_ids,
get_mesh(ipp),
[dist.Replicate(), dist.Replicate()],
)
attention_mask_input = dist.reshard(

Check warning on line 978 in paddlenlp/transformers/llama/modeling_auto.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_auto.py#L978

Added line #L978 was not covered by tests
attention_mask,
get_mesh(ipp),
[dist.Replicate(), dist.Replicate()],
)

if decoder_layer.ipp is not None and pre_ipp != decoder_layer.ipp:
if idx in self.next_pp_stage_indexes:

Check warning on line 984 in paddlenlp/transformers/llama/modeling_auto.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling_auto.py#L984

Added line #L984 was not covered by tests
hidden_states = dist.reshard(
hidden_states,
get_mesh(decoder_layer.ipp),
get_mesh(ipp),
self.placements,
)
position_ids = dist.reshard(
position_ids,
get_mesh(decoder_layer.ipp),
[dist.Shard(0), dist.Replicate()],
)
attention_mask = (
dist.reshard(
attention_mask,
get_mesh(decoder_layer.ipp),
[dist.Shard(0), dist.Replicate()],
)
if attention_mask is not None
else attention_mask
)

if (
self.enable_recompute
Expand All @@ -977,8 +997,8 @@ def forward(
layer_outputs = recompute(
decoder_layer,
hidden_states,
position_ids,
attention_mask,
position_ids_input,
attention_mask_input,
output_attentions,
past_key_value,
use_cache,
Expand All @@ -987,16 +1007,14 @@ def forward(
else:
layer_outputs = decoder_layer(
hidden_states,
position_ids,
attention_mask,
position_ids_input,
attention_mask_input,
output_attentions,
past_key_value,
use_cache,
alibi=alibi,
)

pre_ipp = decoder_layer.ipp

if type(layer_outputs) is tuple:
hidden_states = layer_outputs[0]
else:
Expand Down

0 comments on commit 8110fad

Please sign in to comment.