From f9d91c7f10af252151d96d500d51ea0f3546760d Mon Sep 17 00:00:00 2001 From: 0x45f Date: Wed, 10 Jan 2024 07:51:29 +0000 Subject: [PATCH 1/3] [PIR]Open uts for sequence_mask --- test/sequence/test_sequence_mask.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/sequence/test_sequence_mask.py b/test/sequence/test_sequence_mask.py index 57dee2e13bade..61e3ff4da2b4c 100644 --- a/test/sequence/test_sequence_mask.py +++ b/test/sequence/test_sequence_mask.py @@ -24,6 +24,7 @@ convert_np_dtype_to_dtype_, program_guard, ) +from python.paddle.pir_utils import test_with_pir_api def sequence_mask_wraper(x, maxlen_tensor=None, maxlen=-1, mask_dtype='int64'): @@ -168,15 +169,14 @@ def initParameters(self): class TestSequenceMaskOpError(unittest.TestCase): + @test_with_pir_api def test_errors(self): with program_guard(Program(), Program()): input_data = np.random.uniform(1, 5, [4]).astype("float32") def test_Variable(): # the input must be Variable - paddle.static.nn.sequence_lod.sequence_mask( - input_data, maxlen=4 - ) + paddle.nn.functional.sequence_mask(input_data, maxlen=4) self.assertRaises(TypeError, test_Variable) From c8355313e4bf6cc3b7e92b03887a275986dfb695 Mon Sep 17 00:00:00 2001 From: 0x45f Date: Wed, 10 Jan 2024 11:32:57 +0000 Subject: [PATCH 2/3] Fix ut --- test/sequence/test_sequence_mask.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/sequence/test_sequence_mask.py b/test/sequence/test_sequence_mask.py index 61e3ff4da2b4c..91baeadfaa235 100644 --- a/test/sequence/test_sequence_mask.py +++ b/test/sequence/test_sequence_mask.py @@ -20,9 +20,7 @@ import paddle from paddle.base.framework import ( - Program, convert_np_dtype_to_dtype_, - program_guard, ) from python.paddle.pir_utils import test_with_pir_api @@ -171,7 +169,9 @@ def initParameters(self): class TestSequenceMaskOpError(unittest.TestCase): @test_with_pir_api def test_errors(self): - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): input_data = np.random.uniform(1, 5, [4]).astype("float32") def test_Variable(): From 4908f1cb0c08bbf5567c5995bb0abae94f18c2c6 Mon Sep 17 00:00:00 2001 From: 0x45f Date: Thu, 11 Jan 2024 03:06:05 +0000 Subject: [PATCH 3/3] Fix import --- test/sequence/test_sequence_mask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/sequence/test_sequence_mask.py b/test/sequence/test_sequence_mask.py index 91baeadfaa235..43a70a9c835ca 100644 --- a/test/sequence/test_sequence_mask.py +++ b/test/sequence/test_sequence_mask.py @@ -22,7 +22,7 @@ from paddle.base.framework import ( convert_np_dtype_to_dtype_, ) -from python.paddle.pir_utils import test_with_pir_api +from paddle.pir_utils import test_with_pir_api def sequence_mask_wraper(x, maxlen_tensor=None, maxlen=-1, mask_dtype='int64'):