Commit d46a96e
Adapter new type promotion rule for Paddle 2.6 (PaddlePaddle#8079)
zxcd authored Mar 19, 2024
1 parent 7370d72 commit d46a96e
Showing 3 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion paddlenlp/layers/crf.py
@@ -198,7 +198,7 @@ def _trans_score(self, labels, lengths):
flattened_transition_indices = transition_indices.reshape([-1])
flattened_transition_params = paddle.flatten(self.transitions)
scores = paddle.gather(flattened_transition_params, flattened_transition_indices).reshape([batch_size, -1])
-        mask_scores = scores * mask[:, 1:]
+        mask_scores = scores * mask[:, 1:].astype(scores.dtype)

# Accumulate the transition score
score = paddle.sum(mask_scores, 1)
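
The crf.py change illustrates the pattern this commit applies throughout: cast the integer mask to the float dtype of the other operand instead of relying on implicit promotion in the elementwise multiply. A minimal sketch of that pattern, with made-up shapes and dtypes standing in for the real CRF tensors:

```python
# Minimal sketch of the cast-before-multiply pattern; the shapes and the
# int64 mask are assumptions, not the real CRF inputs.
import paddle

scores = paddle.rand([2, 4], dtype="float32")   # float transition scores
mask = paddle.ones([2, 5], dtype="int64")       # integer sequence mask

# Cast the mask slice to the scores' dtype so the multiply stays float32
# instead of relying on implicit int64 * float32 promotion.
mask_scores = scores * mask[:, 1:].astype(scores.dtype)
print(mask_scores.dtype)  # paddle.float32
```
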
4 changes: 3 additions & 1 deletion paddlenlp/transformers/bloom/modeling.py
@@ -153,7 +153,9 @@ def build_alibi_tensor(attention_mask: Tensor, num_heads: int, dtype) -> Tensor:
# => the query_length dimension will then be broadcasted correctly
# This is more or less identical to T5's relative position bias:
# https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
-    arange_tensor = ((attention_mask.astype(paddle.float32).cumsum(axis=-1) - 1) * attention_mask)[:, None, :]
+    arange_tensor = (
+        (attention_mask.astype(paddle.float32).cumsum(axis=-1) - 1) * attention_mask.astype(paddle.float32)
+    )[:, None, :]
alibi = slopes[..., None] * arange_tensor
# return alibi
return paddle.cast(alibi, dtype)
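
In the BLOOM ALiBi helper, both factors of the multiply are now cast to float32, so the cumulative-sum term and the attention mask share a dtype rather than mixing float32 with the mask's integer dtype. A small sketch of what the rewritten expression computes, using a hypothetical single-sequence mask:

```python
# Sketch of the ALiBi position term with a hypothetical attention mask;
# only the dtype handling mirrors the change above.
import paddle

attention_mask = paddle.to_tensor([[1, 1, 1, 0]], dtype="int64")
mask_f32 = attention_mask.astype(paddle.float32)

# Positions 0, 1, 2 for attended tokens, zeroed again at padding; both
# operands are float32, so no implicit int64/float32 promotion is needed.
arange_tensor = ((mask_f32.cumsum(axis=-1) - 1) * mask_f32)[:, None, :]
print(arange_tensor)  # values: [[[0., 1., 2., 0.]]]
```
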
2 changes: 1 addition & 1 deletion paddlenlp/transformers/electra/modeling.py
@@ -1051,7 +1051,7 @@ def get_discriminator_inputs(self, inputs, raw_inputs, generator_logits, generat
mask_positions = paddle.where(generator_labels == -100, umask_positions, mask_positions)
updated_inputs = self.update_inputs(inputs, sampled_tokids, mask_positions)
# use inputs and updated_input to get discriminator labels
-        labels = mask_positions * (paddle.ones_like(inputs) - paddle.equal(updated_inputs, raw_inputs).astype("int32"))
+        labels = mask_positions * (paddle.ones_like(inputs) - paddle.equal(updated_inputs, raw_inputs).astype("int64"))
return updated_inputs, labels, sampled_tokids

def sample_from_softmax(self, logits, use_softmax_sample=True):
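
The ELECTRA change aligns the cast of the bool comparison with the int64 dtype of paddle.ones_like(inputs), so the subtraction and multiply operate on a single integer dtype. A toy sketch with made-up token ids (using raw_inputs in place of the full discriminator inputs):

```python
# Toy sketch with made-up token ids; it only illustrates the dtype alignment.
import paddle

raw_inputs = paddle.to_tensor([[5, 7, 9]], dtype="int64")       # original tokens
updated_inputs = paddle.to_tensor([[5, 8, 9]], dtype="int64")   # generator output
mask_positions = paddle.to_tensor([[0, 1, 0]], dtype="int64")   # masked positions

# equal() returns bool; casting it to int64 matches ones_like(raw_inputs),
# so every operand in the expression is int64.
labels = mask_positions * (
    paddle.ones_like(raw_inputs) - paddle.equal(updated_inputs, raw_inputs).astype("int64")
)
print(labels)  # [[0, 1, 0]] -> 1 where the generator replaced the token
```
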
