Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2St][NO.13] pir dy2st unittest fix test_seq2seq - Part 4 #60454

Merged
merged 9 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion test/dygraph_to_static/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ if(NOT WITH_GPU)
list(REMOVE_ITEM TEST_OPS test_bert)
list(REMOVE_ITEM TEST_OPS test_transformer)
list(REMOVE_ITEM TEST_OPS test_mobile_net)
list(REMOVE_ITEM TEST_OPS test_seq2seq)
endif()

foreach(TEST_OP ${TEST_OPS})
Expand All @@ -40,7 +41,6 @@ endforeach()
set_tests_properties(test_se_resnet PROPERTIES TIMEOUT 900)
set_tests_properties(test_yolov3 PROPERTIES TIMEOUT 900 LABELS
"RUN_TYPE=EXCLUSIVE")
set_tests_properties(test_seq2seq PROPERTIES TIMEOUT 420)
set_tests_properties(test_cycle_gan PROPERTIES TIMEOUT 150)
set_tests_properties(test_basic_api_transformation PROPERTIES TIMEOUT 240)
set_tests_properties(test_reinforcement_learning PROPERTIES TIMEOUT 120)
Expand All @@ -60,6 +60,7 @@ if(WITH_GPU)
set_tests_properties(test_bert PROPERTIES TIMEOUT 240)
set_tests_properties(test_transformer PROPERTIES TIMEOUT 240)
set_tests_properties(test_mobile_net PROPERTIES TIMEOUT 240)
set_tests_properties(test_seq2seq PROPERTIES TIMEOUT 240)
endif()

# Legacy IR only tests for dygraph_to_static
Expand Down
97 changes: 45 additions & 52 deletions test/dygraph_to_static/test_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,13 @@
from dygraph_to_static_utils import (
Dy2StTestBase,
enable_to_static_guard,
test_legacy_only,
test_sot_mgs0_only,
test_default_and_pir,
)
from seq2seq_dygraph_model import AttentionModel, BaseModel
from seq2seq_utils import Seq2SeqModelHyperParams, get_data_iter

import paddle
from paddle import base
from paddle.base.framework import unique_name
from paddle.nn import ClipGradByGlobalNorm

place = (
Expand All @@ -51,9 +50,8 @@ def prepare_input(batch):


def train(args, attn_model=False):
with base.dygraph.guard(place):
paddle.static.default_startup_program().random_seed = 2020
paddle.static.default_main_program().random_seed = 2020
with unique_name.guard():
paddle.seed(2020)

if attn_model:
model = paddle.jit.to_static(
Expand Down Expand Up @@ -142,52 +140,49 @@ def train(args, attn_model=False):


def infer(args, attn_model=False):
with base.dygraph.guard(place):
if attn_model:
model = paddle.jit.to_static(
AttentionModel(
args.hidden_size,
args.src_vocab_size,
args.tar_vocab_size,
args.batch_size,
beam_size=args.beam_size,
num_layers=args.num_layers,
init_scale=args.init_scale,
dropout=0.0,
mode='beam_search',
)
if attn_model:
model = paddle.jit.to_static(
AttentionModel(
args.hidden_size,
args.src_vocab_size,
args.tar_vocab_size,
args.batch_size,
beam_size=args.beam_size,
num_layers=args.num_layers,
init_scale=args.init_scale,
dropout=0.0,
mode='beam_search',
)
else:
model = paddle.jit.to_static(
BaseModel(
args.hidden_size,
args.src_vocab_size,
args.tar_vocab_size,
args.batch_size,
beam_size=args.beam_size,
num_layers=args.num_layers,
init_scale=args.init_scale,
dropout=0.0,
mode='beam_search',
)
)
else:
model = paddle.jit.to_static(
BaseModel(
args.hidden_size,
args.src_vocab_size,
args.tar_vocab_size,
args.batch_size,
beam_size=args.beam_size,
num_layers=args.num_layers,
init_scale=args.init_scale,
dropout=0.0,
mode='beam_search',
)

model_path = (
args.attn_model_path if attn_model else args.base_model_path
)
state_dict = paddle.load(model_path + '.pdparams')
model.set_dict(state_dict)
model.eval()
train_data_iter = get_data_iter(args.batch_size, mode='infer')
for batch_id, batch in enumerate(train_data_iter):
input_data_feed, word_num = prepare_input(batch)
input_data_feed = [
paddle.to_tensor(np_inp) for np_inp in input_data_feed
]
outputs = paddle.jit.to_static(model.beam_search)(input_data_feed)
break

return outputs.numpy()
model_path = args.attn_model_path if attn_model else args.base_model_path
state_dict = paddle.load(model_path + '.pdparams')
model.set_dict(state_dict)
model.eval()
train_data_iter = get_data_iter(args.batch_size, mode='infer')
for batch_id, batch in enumerate(train_data_iter):
input_data_feed, word_num = prepare_input(batch)
input_data_feed = [
paddle.to_tensor(np_inp) for np_inp in input_data_feed
]
outputs = paddle.jit.to_static(model.beam_search)(input_data_feed)
break

return outputs.numpy()


class TestSeq2seq(Dy2StTestBase):
Expand Down Expand Up @@ -238,14 +233,12 @@ def _test_predict(self, attn_model=False):
msg=f"\npred_dygraph = {pred_dygraph} \npred_static = {pred_static}",
)

@test_sot_mgs0_only
@test_legacy_only
@test_default_and_pir
def test_base_model(self):
self._test_train(attn_model=False)
self._test_predict(attn_model=False)

@test_sot_mgs0_only
@test_legacy_only
@test_default_and_pir
def test_attn_model(self):
self._test_train(attn_model=True)
# TODO(liym27): add predict
Expand Down