Optimize ft when using padded vocabulary #3936

Merged: 12 commits, Jan 4, 2023
10 changes: 9 additions & 1 deletion paddlenlp/ops/fast_transformer/transformer/decoding.py
@@ -1687,7 +1687,15 @@ def transfer_param(p, is_bias=False, dtype="float16", restore_data=False):
     if str(p.dtype)[-len(dtype) :] == dtype and ("gpu" in str(p.place).lower() or "cuda" in str(p.place).lower()):
         return p
     if restore_data:
-        if paddle.in_dynamic_mode():
+        if (
+            getattr(paddle.fluid.framework, "_in_eager_mode_", False)
+            and getattr(paddle.fluid.framework, "_dygraph_tracer_", None) is not None
+        ):
+            param_data = p.numpy()
+            new_p = paddle.create_parameter(shape=param_shape, dtype=dtype, is_bias=is_bias)
+            new_p.set_value(param_data.astype(dtype))
+            return new_p
+        elif paddle.in_dynamic_mode():
             param_data = p.numpy()
             # Creating parameters with Assign initializer is too slow. Maybe we
             # can cast to fp16 directly and get a tensor, while we do it more
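The fast path added above pulls the parameter data out as NumPy, creates a fresh parameter of the target dtype, and writes the cast values back with `set_value`, instead of going through the much slower Assign-initializer route. A minimal standalone sketch of that conversion (the helper name and the float32-to-float16 scenario are illustrative, not part of the PR):

```python
import paddle

def cast_param_sketch(p, is_bias=False, dtype="float16"):
    # Copy the data out, build a new parameter of the target dtype,
    # then restore the (cast) values onto it.
    param_data = p.numpy()
    new_p = paddle.create_parameter(shape=p.shape, dtype=dtype, is_bias=is_bias)
    new_p.set_value(param_data.astype(dtype))
    return new_p

# Usage sketch (assumes eager/dygraph mode, the default in recent Paddle releases):
w = paddle.create_parameter(shape=[8, 8], dtype="float32")
w_fp16 = cast_param_sketch(w)
print(w_fp16.dtype)  # paddle.float16
```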
@@ -350,10 +350,10 @@ __global__ void apply_logits_mask_kernel(int vocab_size_padded,
   bool finish = (finished != nullptr) ? finished[bbid] : false;

   if (!finish) {
-    for (int i = tid + bid * blockDim.x; i < vocab_size;
+    for (int i = tid + bid * blockDim.x; i < vocab_size_padded;
          i += blockDim.x * gridDim.x) {
-      if (min_penalty && i == end_id) {
-        log_probs[i + bbid * vocab_size_padded] += -MAX_T_VAL;
+      if ((min_penalty && i == end_id) || i >= vocab_size) {
+        log_probs[i + bbid * vocab_size_padded] = -MAX_T_VAL;
       } else if (logits_mask) {
         log_probs[i + bbid * vocab_size_padded] += logits_mask[i];
       } else if (bias) {
@@ -377,7 +377,7 @@ void apply_logits_mask_kernelLauncher(T* log_probs,
                                       const bool min_penalty,
                                       const int end_id,
                                       const T* bias) {
-  if (logits_mask == nullptr && !min_penalty && bias == nullptr) return;
+  if (logits_mask == nullptr && !min_penalty && bias == nullptr && vocab_size == vocab_size_padded) return;

   dim3 block(256);
   dim3 grid((vocab_size_padded + block.x - 1) / block.x,
@@ -1006,7 +1006,9 @@ class DecodingBeamsearch {
       check_cuda_error(cudaGetLastError());
 #endif

-      if (decoding_params.logits_mask || (args_.min_length_ != 0 && step <= args_.min_length_)) {
+      if (decoding_params.logits_mask ||
+          (args_.min_length_ != 0 && step <= args_.min_length_) ||
+          args_.vocab_size_padded_ != args_.vocab_size_) {
         apply_logits_mask_kernelLauncher(
             tmp_logits_buf_,
             keep_alive_beam_ ? alive_finished_buf_ : finished_buf_,
@@ -955,7 +955,9 @@ class DecodingSampling {
 #endif
       }

-      if (decoding_params.logits_mask || (args_.min_length_ != 0 && step <= args_.min_length_)) {
+      if (decoding_params.logits_mask ||
+          (args_.min_length_ != 0 && step <= args_.min_length_) ||
+          args_.vocab_size_padded_ != args_.vocab_size_) {
         apply_logits_mask_kernelLauncher(logits_buf_,
                                          finished_buf_,
                                          args_.batch_size_,
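Taken together, the CUDA and C++ changes above make the padded part of the vocabulary harmless: the logits-mask kernel now walks the full padded width and forces every position at or beyond the real vocab_size to -MAX_T_VAL, the launcher no longer early-returns when vocab_size differs from vocab_size_padded, and both DecodingBeamsearch and DecodingSampling invoke the kernel whenever padding is present, so beam search and sampling can never select a padded token id. A rough NumPy analogue of that masking step (illustrative only, not the CUDA code path, and the vocabulary sizes are made up):

```python
import numpy as np

NEG_INF = np.finfo(np.float32).min  # stands in for -MAX_T_VAL

def mask_padded_vocab(log_probs, vocab_size, min_penalty=False, end_id=None):
    """log_probs: [batch * beam, vocab_size_padded]; neutralize the padded tail."""
    log_probs[:, vocab_size:] = NEG_INF      # padded ids can never win argmax/sampling
    if min_penalty and end_id is not None:
        log_probs[:, end_id] = NEG_INF       # forbid EOS while below min_length
    return log_probs

# e.g. a vocabulary of 50000 padded to 50048 for friendlier GEMM shapes
logits = np.random.randn(4, 50048).astype("float32")
masked = mask_padded_vocab(logits, vocab_size=50000)
assert (masked[:, 50000:] == NEG_INF).all()
```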
14 changes: 13 additions & 1 deletion paddlenlp/transformers/bart/modeling.py
@@ -532,11 +532,23 @@ def forward(
                 cache = self.decoder.decoder.gen_cache(encoder_last_hidden_state)
         else:
             cache = None
+
+        memory_mask = attention_mask
+        if attention_mask is not None:
+            if attention_mask.ndim == 4:
+                memory_mask = attention_mask[:, :, -1:, :]
+            elif attention_mask.ndim == 3:
+                memory_mask = attention_mask[:, -1:, :].unsqueeze([1])
+            elif attention_mask.ndim == 2:
+                memory_mask = attention_mask.unsqueeze([1, 2])
+            else:
+                raise ValueError("Invalid attention mask shape. ")
+
         decoder_output = self.decoder(
             decoder_input_ids,
             decoder_attention_mask,
             encoder_last_hidden_state,
-            attention_mask,
+            memory_mask,
             cache=cache,
             decoder_inputs_embeds=decoder_inputs_embeds,
             output_attentions=output_attentions,
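The memory_mask block above normalizes whatever mask rank the caller provides into the form the decoder's cross-attention expects during step-by-step generation: a 4-D mask keeps only the last query row, a 3-D mask does the same and regains the head axis, and a 2-D padding mask is broadcast up to [batch, 1, 1, src_len]. A small sketch of that normalization in isolation (the shape comments are assumptions about the conventional layouts, not taken from the PR):

```python
import paddle

def normalize_memory_mask(attention_mask):
    # Mirrors the branch above: reduce any supported mask rank to a 4-D mask
    # covering only the current (last) query position against all source tokens.
    if attention_mask is None:
        return None
    if attention_mask.ndim == 4:      # assumed [batch, heads_or_1, tgt_len, src_len]
        return attention_mask[:, :, -1:, :]
    if attention_mask.ndim == 3:      # assumed [batch, tgt_len, src_len]
        return attention_mask[:, -1:, :].unsqueeze([1])
    if attention_mask.ndim == 2:      # assumed [batch, src_len]
        return attention_mask.unsqueeze([1, 2])
    raise ValueError("Invalid attention mask shape.")

# A [batch, src_len] padding mask becomes [batch, 1, 1, src_len]:
mask_2d = paddle.ones([2, 7], dtype="float32")
print(normalize_memory_mask(mask_2d).shape)  # [2, 1, 1, 7]
```

MBart and T5 below apply the same normalization; T5 names the result encoder_attention_mask.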
14 changes: 13 additions & 1 deletion paddlenlp/transformers/mbart/modeling.py
@@ -548,11 +548,23 @@ def forward(
                 cache = self.decoder.decoder.gen_cache(encoder_last_hidden_state)
         else:
             cache = None
+
+        memory_mask = attention_mask
+        if attention_mask is not None:
+            if attention_mask.ndim == 4:
+                memory_mask = attention_mask[:, :, -1:, :]
+            elif attention_mask.ndim == 3:
+                memory_mask = attention_mask[:, -1:, :].unsqueeze([1])
+            elif attention_mask.ndim == 2:
+                memory_mask = attention_mask.unsqueeze([1, 2])
+            else:
+                raise ValueError("Invalid attention mask shape. ")
+
         decoder_output = self.decoder(
             decoder_input_ids,
             decoder_attention_mask,
             encoder_last_hidden_state,
-            attention_mask,
+            memory_mask,
             cache,
             decoder_inputs_embeds=decoder_inputs_embeds,
             output_attentions=output_attentions,
14 changes: 13 additions & 1 deletion paddlenlp/transformers/t5/modeling.py
@@ -1463,14 +1463,25 @@ def forward(
             if decoder_input_ids is not None:
                 decoder_input_ids = decoder_input_ids[:, -1:]

+        encoder_attention_mask = attention_mask
+        if attention_mask is not None:
+            if attention_mask.ndim == 4:
+                encoder_attention_mask = attention_mask[:, :, -1:, :]
+            elif attention_mask.ndim == 3:
+                encoder_attention_mask = attention_mask[:, -1:, :].unsqueeze([1])
+            elif attention_mask.ndim == 2:
+                encoder_attention_mask = attention_mask.unsqueeze([1, 2])
+            else:
+                raise ValueError("Invalid attention mask shape. ")
+
         # Decode
         decoder_outputs = self.t5.decoder(
             input_ids=decoder_input_ids,
             attention_mask=decoder_attention_mask,
             inputs_embeds=decoder_inputs_embeds,
             cache=cache,
             encoder_hidden_states=hidden_states,
-            encoder_attention_mask=attention_mask,
+            encoder_attention_mask=encoder_attention_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -1545,6 +1556,7 @@ def prepare_inputs_for_generation(
         # cut decoder_input_ids if past is used
         if cache is not None:
             input_ids = input_ids[:, -1:]
+
         return {
             "decoder_input_ids": input_ids,
             "cache": cache,
4 changes: 2 additions & 2 deletions paddlenlp/transformers/unified_transformer/modeling.py
@@ -559,8 +559,8 @@ def prepare_inputs_for_generation(
                 position_ids = position_ids[:, -1:]
             if role_ids is not None:
                 role_ids = role_ids[:, -1:]
-        if attention_mask is not None:
-            attention_mask = attention_mask[:, :, -1:, :]
+            if attention_mask is not None:
+                attention_mask = attention_mask[:, :, -1:, :]

         return {
             "input_ids": input_ids,
4 changes: 2 additions & 2 deletions paddlenlp/transformers/unimo/modeling.py
@@ -523,8 +523,8 @@ def prepare_inputs_for_generation(
                 token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
             if position_ids is not None:
                 position_ids = position_ids[:, -1].unsqueeze(-1)
-        if attention_mask is not None:
-            attention_mask = attention_mask[:, :, -1:, :]
+            if attention_mask is not None:
+                attention_mask = attention_mask[:, :, -1:, :]

         return {
             "input_ids": input_ids,
2 changes: 2 additions & 0 deletions tests/transformers/test_generation_utils.py
@@ -65,6 +65,8 @@ def _get_input_ids_and_config(self):
         input_ids = input_ids[:max_batch_size, :sequence_length]
         attention_mask = attention_mask[:max_batch_size, :sequence_length].unsqueeze([1, 2])

+        attention_mask = attention_mask * attention_mask.transpose([0, 1, 3, 2])
+
         # generate max 3 tokens
         max_length = 3
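The new test line above turns the broadcastable [batch, 1, 1, seq_len] padding mask into a full [batch, 1, seq_len, seq_len] square mask, which is the shape the prepare_inputs_for_generation changes slice with [:, :, -1:, :] at each cached decoding step. A quick sketch of why the elementwise product with the transpose gives that square mask (shapes and values here are illustrative only):

```python
import paddle

# A padding mask for two sequences of length 4, the second padded after 2 tokens.
pad_mask = paddle.to_tensor([[1, 1, 1, 1], [1, 1, 0, 0]], dtype="float32")
mask = pad_mask.unsqueeze([1, 2])                 # [2, 1, 1, 4]

# Multiplying by the transpose broadcasts to [2, 1, 4, 4]: position (i, j) is
# kept only when both token i and token j are real (non-padding) tokens.
square = mask * mask.transpose([0, 1, 3, 2])
print(square.shape)   # [2, 1, 4, 4]
print(square[1, 0])   # rows and columns 2-3 are zeroed for the padded sequence
```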
2 changes: 1 addition & 1 deletion tests/transformers/unified_transformer/test_modeling.py
@@ -556,7 +556,7 @@ def test_unified_transformer_sample(self):
         )
         output_str = postprocess_response(output_ids[0].numpy(), tokenizer)

-        EXPECTED_OUTPUT_STR = "你 在 做 什么 呢 ?"
+        EXPECTED_OUTPUT_STR = "你 在 哪里 呀 ?"
         self.assertEqual(output_str, EXPECTED_OUTPUT_STR)

     def test_generate_without_input_ids(self):