[Dy2St] pir dy2st unittest verification - Part 13 #59517

Merged: 6 commits, Dec 5, 2023
20 changes: 14 additions & 6 deletions python/paddle/optimizer/optimizer.py
@@ -31,6 +31,7 @@
     in_dynamic_or_pir_mode,
     in_pir_mode,
     name_scope,
+    use_pir_api,
 )
 from paddle.regularizer import L2Decay

@@ -788,12 +789,19 @@ def _create_param_lr(self, param_and_grad):
         if param_lr == 1.0:
             return self._global_learning_rate()
         else:
-            with paddle.static.default_main_program()._lr_schedule_guard(
-                is_with_opt=True
-            ), framework.name_scope(
-                'scale_with_param_lr'
-            ):
-                return self._global_learning_rate() * param_lr
+            if not use_pir_api():
+                with paddle.static.default_main_program()._lr_schedule_guard(
+                    is_with_opt=True
+                ), framework.name_scope(
+                    'scale_with_param_lr'
+                ):
+                    return self._global_learning_rate() * param_lr
+            else:
+                # TODO(dev): Currently there is no equivalent of op_role in
+                # PIR mode, so we simply remove _lr_schedule_guard here; this
+                # should be fixed in the future.
+                with framework.name_scope('scale_with_param_lr'):
+                    return self._global_learning_rate() * param_lr
Comment on lines +792 to +804

Member: In principle this should live in Program._lr_schedule_guard, but there is no corresponding patch on the Python side yet, so it is written directly like this for now.

         else:
             return self._global_learning_rate()
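The review note above suggests this branch could eventually move into `Program._lr_schedule_guard` itself. A minimal sketch of that idea, assuming a context-manager wrapper on the Python side (the name `_lr_schedule_guard_compat` and the import path are illustrative, not part of this PR):

```python
import contextlib


@contextlib.contextmanager
def _lr_schedule_guard_compat(program, is_with_opt=False):
    # Illustrative sketch: PIR has no op_role equivalent yet, so the
    # guard degrades to a no-op when the PIR API is enabled.
    from paddle.base.framework import use_pir_api  # assumed location

    if use_pir_api():
        yield
    else:
        with program._lr_schedule_guard(is_with_opt=is_with_opt):
            yield
```

With such a wrapper in place, `_create_param_lr` would not need its own `use_pir_api()` branch.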
5 changes: 3 additions & 2 deletions test/dygraph_to_static/test_convert_call_generator.py
@@ -17,10 +17,10 @@
 from dygraph_to_static_utils import (
     Dy2StTestBase,
     test_ast_only,
+    test_legacy_and_pt_and_pir,
 )

 import paddle
-from paddle.jit import to_static
 from paddle.jit.dy2static.convert_call_func import translator_logger


@@ -38,12 +38,13 @@ def main_func():
 class TestConvertGenerator(Dy2StTestBase):
     # fallback will ok.
     @test_ast_only
+    @test_legacy_and_pt_and_pir
     def test_raise_error(self):
         translator_logger.verbosity_level = 1
         with self.assertLogs(
             translator_logger.logger_name, level='WARNING'
         ) as cm:
-            to_static(main_func)()
+            paddle.jit.to_static(main_func)()
         self.assertRegex(
             cm.output[0],
             "Your function:`dyfunc_generator` doesn't support "
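`test_legacy_and_pt_and_pir` is defined in the shared `dygraph_to_static_utils` helper, whose implementation is outside this diff. As a rough mental model only (the mode-switching helper below is hypothetical, not Paddle's actual utility), a decorator of this kind re-runs one test body under each IR mode:

```python
import functools


def run_in_modes(*modes):
    # Hypothetical sketch of a mode-sweeping test decorator.
    def decorator(test_func):
        @functools.wraps(test_func)
        def wrapper(self, *args, **kwargs):
            for mode in modes:
                with self.subTest(mode=mode):
                    self.switch_ir_mode(mode)  # hypothetical helper
                    test_func(self, *args, **kwargs)

        return wrapper

    return decorator
```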
42 changes: 22 additions & 20 deletions test/dygraph_to_static/test_full_name_usage.py
@@ -15,15 +15,16 @@
 import unittest

 import numpy as np
-from dygraph_to_static_utils import Dy2StTestBase, test_ast_only
+from dygraph_to_static_utils import (
+    Dy2StTestBase,
+    test_ast_only,
+)

 import paddle
-from paddle import base


-@paddle.jit.to_static(full_graph=True)
 def dygraph_decorated_func(x):
-    x = base.dygraph.to_variable(x)
+    x = paddle.to_tensor(x)
     if paddle.mean(x) > 0:
         x_v = x - 1
     else:
@@ -33,7 +34,7 @@ def dygraph_decorated_func(x):

 @paddle.jit.to_static(full_graph=True)
 def jit_decorated_func(x):
-    x = base.dygraph.to_variable(x)
+    x = paddle.to_tensor(x)
     if paddle.mean(x) > 0:
         x_v = x - 1
     else:
@@ -50,7 +51,7 @@ class DoubleDecorated:
     @classmethod
     @paddle.jit.to_static(full_graph=True)
     def double_decorated_func1(self, x):
-        return dygraph_decorated_func(x)
+        return paddle.jit.to_static(dygraph_decorated_func)(x)

     @classmethod
     @paddle.jit.to_static(full_graph=True)
@@ -63,20 +64,21 @@ class TestFullNameDecorator(Dy2StTestBase):
     def test_run_success(self):
         x = np.ones([1, 2]).astype("float32")
         answer = np.zeros([1, 2]).astype("float32")
-        with base.dygraph.guard():
-            np.testing.assert_allclose(
-                dygraph_decorated_func(x).numpy(), answer, rtol=1e-05
-            )
-            np.testing.assert_allclose(
-                jit_decorated_func(x).numpy(), answer, rtol=1e-05
-            )
-            np.testing.assert_allclose(
-                decorated_call_decorated(x).numpy(), answer, rtol=1e-05
-            )
-            with self.assertRaises((NotImplementedError, TypeError)):
-                DoubleDecorated().double_decorated_func1(x)
-            with self.assertRaises((NotImplementedError, TypeError)):
-                DoubleDecorated().double_decorated_func2(x)
+        np.testing.assert_allclose(
+            paddle.jit.to_static(dygraph_decorated_func)(x).numpy(),
+            answer,
+            rtol=1e-05,
+        )
+        np.testing.assert_allclose(
+            jit_decorated_func(x).numpy(), answer, rtol=1e-05
+        )
+        np.testing.assert_allclose(
+            decorated_call_decorated(x).numpy(), answer, rtol=1e-05
+        )
+        with self.assertRaises((NotImplementedError, TypeError)):
+            DoubleDecorated().double_decorated_func1(x)
+        with self.assertRaises((NotImplementedError, TypeError)):
+            DoubleDecorated().double_decorated_func2(x)


 if __name__ == '__main__':
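The recurring change in this PR: tests no longer bake `@paddle.jit.to_static` into the function definition but convert at the call site, so the undecorated dygraph function stays reusable. Both forms below use the same public API; the toy function is just for illustration:

```python
import paddle


def double_relu(x):
    return paddle.nn.functional.relu(x) * 2


# Call-site conversion, as the updated tests do:
out = paddle.jit.to_static(double_relu)(paddle.randn([2, 3]))


# Decorator form, as jit_decorated_func still uses:
@paddle.jit.to_static(full_graph=True)
def double_relu_static(x):
    return paddle.nn.functional.relu(x) * 2
```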
53 changes: 33 additions & 20 deletions test/dygraph_to_static/test_slice.py
@@ -17,7 +17,11 @@
 import unittest

 import numpy as np
-from dygraph_to_static_utils import Dy2StTestBase, test_ast_only
+from dygraph_to_static_utils import (
+    Dy2StTestBase,
+    test_ast_only,
+    test_legacy_and_pt_and_pir,
+)

 import paddle
 from paddle.static import InputSpec
@@ -26,7 +30,6 @@
 np.random.seed(SEED)


-@paddle.jit.to_static
 def test_slice_without_control_flow(x):
     # Python slice will not be transformed.
     x = paddle.to_tensor(x)
@@ -35,7 +38,6 @@ def test_slice_without_control_flow(x):
     return a[0]


-@paddle.jit.to_static
 def test_slice_in_if(x):
     x = paddle.to_tensor(x)
     a = []
@@ -70,7 +72,6 @@ def test_slice_in_while_loop(x, iter_num=3):
     return out[0]


-@paddle.jit.to_static
 def test_slice_in_for_loop(x, iter_num=3):
     x = paddle.to_tensor(x)
     a = []
@@ -88,7 +89,6 @@ def test_slice_in_for_loop(x, iter_num=3):
     return out


-@paddle.jit.to_static
 def test_set_value(x):
     x = paddle.to_tensor(x)
     x[0] = paddle.full(shape=[1], fill_value=2, dtype="float32")
@@ -101,29 +101,25 @@ def __init__(self, input_dim, hidden):
         super().__init__()
         self.linear = paddle.nn.Linear(input_dim, hidden)

-    @paddle.jit.to_static
     def forward(self, x):
         x = self.linear(x)
         x[0] = 1
         return x


-class TestSliceWithoutControlFlow(Dy2StTestBase):
+class TestSliceBase(Dy2StTestBase):
     def setUp(self):
         self.init_input()
-        self.place = (
-            paddle.CUDAPlace(0)
-            if paddle.is_compiled_with_cuda()
-            else paddle.CPUPlace()
-        )
-        self.init_dygraph_func()
+        self.dygraph_func = None
         paddle.disable_static()

     def init_input(self):
         self.input = np.random.random(3).astype('int32')

     def init_dygraph_func(self):
-        self.dygraph_func = test_slice_without_control_flow
+        raise NotImplementedError(
+            "For Enumerate test should implement set_test_func"
+        )

     def run_dygraph_mode(self):
         return self._run(to_static=False)
@@ -140,28 +136,41 @@ def _run(self, to_static):
     def run_static_mode(self):
         return self._run(to_static=True)

+
+class TestSliceWithoutControlFlow(TestSliceBase):
+    def init_dygraph_func(self):
+        self.dygraph_func = test_slice_without_control_flow
+
+    @test_legacy_and_pt_and_pir
     def test_transformed_static_result(self):
         self.init_dygraph_func()
         static_res = self.run_static_mode()
         dygraph_res = self.run_dygraph_mode()
         np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05)


-class TestSliceInIf(TestSliceWithoutControlFlow):
+class TestSliceInIf(TestSliceBase):
     def init_dygraph_func(self):
         self.dygraph_func = test_slice_in_if

+    def test_transformed_static_result(self):
+        self.init_dygraph_func()
+        static_res = self.run_static_mode()
+        dygraph_res = self.run_dygraph_mode()
+        np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05)
+

-class TestSliceInWhileLoop(TestSliceWithoutControlFlow):
+class TestSliceInWhileLoop(TestSliceInIf):
     def init_dygraph_func(self):
-        self.dygraph_func = paddle.jit.to_static(test_slice_in_while_loop)
+        self.dygraph_func = test_slice_in_while_loop


-class TestSliceInForLoop(TestSliceWithoutControlFlow):
+class TestSliceInForLoop(TestSliceInIf):
     def init_dygraph_func(self):
         self.dygraph_func = test_slice_in_for_loop


-class TestSetValue(TestSliceWithoutControlFlow):
+class TestSetValue(TestSliceInIf):
     def init_input(self):
         self.input = np.full([3, 4, 5], 5).astype('float32')
@@ -182,7 +191,7 @@ def tearDown(self):
     @test_ast_only
     def test_set_value_with_save(self):
         paddle.jit.enable_to_static(True)
-        model = LayerWithSetValue(input_dim=10, hidden=1)
+        model = paddle.jit.to_static(LayerWithSetValue(input_dim=10, hidden=1))
         x = paddle.full(shape=[5, 10], fill_value=5.0, dtype="float32")
         paddle.jit.save(
             layer=model, path=self.model_path, input_spec=[x], output_spec=None
@@ -191,6 +200,7 @@ def test_set_value_with_save(self):

 class TestSliceSupplementSpecialCase(Dy2StTestBase):
     # unittest for slice index which abs(step)>0. eg: x[::2]
+    @test_legacy_and_pt_and_pir
     def test_static_slice_step(self):
         paddle.enable_static()
         array = np.arange(4**3).reshape((4, 4, 4)).astype('int64')
@@ -209,6 +219,7 @@ def test_static_slice_step(self):
         np.testing.assert_array_equal(out[0], array[::2])
         np.testing.assert_array_equal(out[1], array[::-2])

+    @test_legacy_and_pt_and_pir
     def test_static_slice_step_dygraph2static(self):
         paddle.disable_static()

@@ -233,6 +244,7 @@ def func(inps):


 class TestPaddleStridedSlice(Dy2StTestBase):
+    @test_legacy_and_pt_and_pir
     def test_compare_paddle_strided_slice_with_numpy(self):
         paddle.disable_static()
         array = np.arange(5)
@@ -294,6 +306,7 @@ def slice_zero_shape_tensor(x):


 class TestSliceZeroShapeTensor(Dy2StTestBase):
+    @test_legacy_and_pt_and_pir
     def test_slice(self):
         paddle.disable_static()
         x = paddle.ones([0, 0, 0, 0])
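The `_run(self, to_static)` helper these classes depend on is collapsed out of the diff view. Judging from how `run_dygraph_mode` and `run_static_mode` call it, it plausibly looks like the sketch below (an assumption about the hidden code, not the file's actual body):

```python
def _run(self, to_static: bool):
    # Sketch: apply to_static on demand so one dygraph function
    # backs both execution paths.
    func = (
        paddle.jit.to_static(self.dygraph_func)
        if to_static
        else self.dygraph_func
    )
    return func(self.input).numpy()
```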
2 changes: 2 additions & 0 deletions test/dygraph_to_static/test_spec_names.py
@@ -17,6 +17,7 @@
 from dygraph_to_static_utils import (
     Dy2StTestBase,
     test_ast_only,
+    test_legacy_and_pt_and_pir,
 )

 import paddle
@@ -48,6 +49,7 @@ def read_from_dataset(self):
         self.n = paddle.randn([4, 2, 8])

     @test_ast_only
+    @test_legacy_and_pt_and_pir
     def test_spec_name_hash(self):
         net = Net()
         net = paddle.jit.to_static(net)