Commit

add unittest
DesmonDay committed Feb 29, 2024
1 parent bcb5aea commit fbe3b78
Showing 3 changed files with 331 additions and 2 deletions.
2 changes: 0 additions & 2 deletions paddlenlp/transformers/mixtral/modeling.py
@@ -1070,8 +1070,6 @@ def __init__(self, config: MixtralConfig):
         )
         self.norm = MixtralRMSNorm(config)
 
-        self.gradient_checkpointing = False
-
     def get_input_embeddings(self):
         return self.embed_tokens

13 changes: 13 additions & 0 deletions tests/transformers/mixtral/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
318 changes: 318 additions & 0 deletions tests/transformers/mixtral/test_modeling.py
@@ -0,0 +1,318 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import unittest

import paddle

from paddlenlp.transformers import MixtralConfig, MixtralForCausalLM, MixtralModel
from tests.transformers.test_configuration_common import ConfigTester
from tests.transformers.test_generation_utils import GenerationTesterMixin
from tests.transformers.test_modeling_common import (
    ModelTesterMixin,
    ids_tensor,
    random_attention_mask,
)


class MixtralModelTester:
    def __init__(
        self,
        parent,
        vocab_size=32000,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=8,
        masked_softmax_fusion=True,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        is_training=True,
        use_cache=False,
        bos_token_id=1,
        eos_token_id=2,
        apply_residual_connection_post_layernorm=False,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        attention_softmax_in_fp32=True,
        pretraining_tp=1,  # TP rank used when training with megatron
        dtype="bfloat16",
        slow_but_exact=False,
        batch_size: int = 2,
        seq_length: int = 10,
        type_sequence_label_size=2,
        activation_function="gelu",
        num_labels=3,
        num_choices=4,
        scope=None,
        dropout=0.56,
        use_input_mask: bool = False,
        use_labels: bool = False,
        return_dict=False,
    ):
        self.parent: MixtralModelTest = parent
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.masked_softmax_fusion = masked_softmax_fusion
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.is_training = is_training
        self.use_cache = use_cache
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.pretraining_tp = pretraining_tp
        self.dtype = dtype
        self.slow_but_exact = slow_but_exact

        self.batch_size = batch_size
        self.seq_length = seq_length
        self.type_sequence_label_size = type_sequence_label_size
        self.activation_function = activation_function
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope
        self.dropout = dropout

        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.return_dict = return_dict

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size, dtype=paddle.int64)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()
        return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self) -> MixtralConfig:
        return MixtralConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            masked_softmax_fusion=self.masked_softmax_fusion,
            layer_norm_epsilon=self.layer_norm_epsilon,
            initializer_range=self.initializer_range,
            use_cache=self.use_cache,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
            hidden_dropout=self.hidden_dropout,
            attention_dropout=self.attention_dropout,
            attention_softmax_in_fp32=self.attention_softmax_in_fp32,
            pretraining_tp=self.pretraining_tp,
            dtype=self.dtype,
            slow_but_exact=self.slow_but_exact,
            activation_function=self.activation_function,
        )

    def create_and_check_model(
        self, config: MixtralConfig, input_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = MixtralModel(config)
        model.eval()
        result = model(input_ids)
        self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.hidden_size])

    def create_and_check_model_attention_mask(
        self, config: MixtralConfig, input_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = MixtralModel(config)
        model.eval()
        attn_mask_2d = random_attention_mask([self.batch_size, self.seq_length])
        result_2d = model(input_ids, attention_mask=attn_mask_2d)[0]
        batch, seq_length = input_ids.shape
        causal_mask = paddle.tril(paddle.ones((batch, seq_length, seq_length), dtype=attn_mask_2d.dtype))
        attn_mask_3d = causal_mask & attn_mask_2d.unsqueeze(-1)
        result_3d = model(input_ids, attention_mask=attn_mask_3d)[0]
        attn_mask_4d = attn_mask_3d.unsqueeze(1)
        result_4d = model(input_ids, attention_mask=attn_mask_4d)[0]
        result_no_attention_mask = model(input_ids, attention_mask=None)[0]
        # Assert that non-padding tokens get the same logits regardless of the attention_mask shape
        self.parent.assertTrue((result_2d[attn_mask_2d] == result_3d[attn_mask_2d]).all())
        self.parent.assertTrue((result_2d[attn_mask_2d] == result_4d[attn_mask_2d]).all())
        self.parent.assertTrue((result_2d[attn_mask_2d] == result_no_attention_mask[attn_mask_2d]).all())

    def create_and_check_model_past_large_inputs(
        self,
        config: MixtralConfig,
        input_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
    ):
        model = MixtralModel(config)
        model.eval()

        # first forward pass
        outputs = model(input_ids, attention_mask=input_mask, use_cache=True, return_dict=self.return_dict)
        past_key_values = outputs.past_key_values if self.return_dict else outputs[2]

        # create hypothetical next tokens to extend next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), self.vocab_size)
        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

        # append to next input_ids and attention_mask
        next_input_ids = paddle.concat([input_ids, next_tokens], axis=-1)
        next_attention_mask = paddle.concat([input_mask, next_mask], axis=-1)

        outputs = model(
            next_input_ids, attention_mask=next_attention_mask, output_hidden_states=True, return_dict=self.return_dict
        )

        output_from_no_past = outputs[2][0]

        outputs = model(
            next_tokens,
            attention_mask=next_attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
            return_dict=self.return_dict,
        )

        output_from_past = outputs[2][0]

        # select a random slice of the hidden dimension
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(paddle.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict

    def create_and_check_lm_head_model(self, config, input_ids, input_mask, *args):
        model = MixtralForCausalLM(config)
        model.eval()

        result = model(
            input_ids,
            use_cache=True,
            labels=input_ids if self.parent.use_labels else None,
            return_dict=self.parent.return_dict,
        )
        if self.parent.use_labels:
            self.parent.assertIsInstance(result[0].item(), float)
            self.parent.assertEqual(result[1].shape, [self.batch_size, self.seq_length, self.vocab_size])
        else:
            self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.vocab_size])

    def check_model_position_ids(self, config, input_ids, input_mask, *args):
        model = MixtralForCausalLM(config)
        model.eval()

        result_no_position_id = model(
            input_ids,
            labels=input_ids if self.parent.use_labels else None,
            return_dict=self.parent.return_dict,
        )
        batch_size, seq_len = input_ids.shape
        position_ids = paddle.arange(seq_len).expand((batch_size, seq_len))
        result_position_id = model(
            input_ids,
            position_ids,
            labels=input_ids if self.parent.use_labels else None,
            return_dict=self.parent.return_dict,
        )
        if self.parent.use_labels:
            self.parent.assertTrue((result_position_id[1] == result_no_position_id[1]).all())
        else:
            self.parent.assertTrue((result_position_id[0] == result_no_position_id[0]).all())


class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    base_model_class = MixtralModel
    return_dict = False
    use_labels = False
    use_test_model_name_list = False

    all_model_classes = (MixtralModel, MixtralForCausalLM)
    all_generative_model_classes = {MixtralForCausalLM: (MixtralModel, "mixtral")}

    def setUp(self):
        super().setUp()

        self.model_tester = MixtralModelTester(self)
        self.config_tester = ConfigTester(self, config_class=MixtralConfig, vocab_size=256, hidden_size=24)

    def _get_input_ids_and_config(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        input_ids = inputs_dict[self.input_name]
        attention_mask = paddle.ones_like(input_ids, dtype=paddle.int64)

        max_batch_size = 2
        sequence_length = input_ids.shape[-1] // 2
        input_ids = input_ids[:max_batch_size, :sequence_length]
        attention_mask = attention_mask[:max_batch_size, :sequence_length]
        max_length = 3

        return config, input_ids, attention_mask, max_length

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_attention_mask(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_attention_mask(*config_and_inputs)

    def test_model_position_ids(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.check_model_position_ids(*config_and_inputs)

    def test_generate_without_input_ids(self):
        # this requires 4-D attention mask logic, which is not supported yet
        pass

    def test_mixtral_lm_head_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_lm_head_model(*config_and_inputs)


if __name__ == "__main__":
    unittest.main()
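
The new file is a standard unittest module, so it can also be run on its own from a PaddleNLP source checkout. A minimal sketch of a local run is shown below; the discovery path and working directory are assumptions about a typical dev setup, not part of this commit.

# Minimal sketch: run only the Mixtral tests added in this commit.
# Assumes the repository root as working directory, with paddlenlp and
# paddlepaddle installed; the path below is an assumption for illustration.
import unittest

suite = unittest.defaultTestLoader.discover(
    "tests/transformers/mixtral", pattern="test_modeling.py"
)
unittest.TextTestRunner(verbosity=2).run(suite)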
