Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SOT][Faster Guard] add ENV_SOT_ENABLE_FASTER_GUARD #69263

Merged
merged 1 commit into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions python/paddle/jit/sot/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,19 @@
ENV_COST_MODEL,
ENV_MIN_GRAPH_SIZE,
ENV_SOT_ALLOW_DYNAMIC_SHAPE,
ENV_SOT_ENABLE_FASTER_GUARD,
ENV_SOT_EXPORT,
ENV_SOT_LOG_LEVEL,
ENV_SOT_WITH_CONTROL_FLOW,
ENV_STRICT_MODE,
allow_dynamic_shape_guard,
cost_model_guard,
export_guard,
faster_guard_guard,
min_graph_size_guard,
sot_step_profiler_guard,
strict_mode_guard,
with_allow_dynamic_shape_guard,
with_control_flow_guard,
with_export_guard,
)
from .exceptions import ( # noqa: F401
BreakGraphError,
Expand Down
14 changes: 12 additions & 2 deletions python/paddle/jit/sot/utils/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
# Enable SOT dynamic shape as default in PIR mode only
paddle.framework.use_pir_api(),
)
ENV_SOT_ENABLE_FASTER_GUARD = BooleanEnvironmentVariable(
"SOT_ENABLE_FASTER_GUARD",
False,
)
ENV_SOT_EVENT_LEVEL = IntegerEnvironmentVariable("SOT_EVENT_LEVEL", 0)
ENV_ENABLE_SOT_STEP_PROFILER = BooleanEnvironmentVariable(
"ENABLE_SOT_STEP_PROFILER", False
Expand Down Expand Up @@ -69,17 +73,23 @@ def with_control_flow_guard(value: bool):


@contextmanager
def with_export_guard(value: str):
def export_guard(value: str):
with EnvironmentVariableGuard(ENV_SOT_EXPORT, value):
yield


@contextmanager
def with_allow_dynamic_shape_guard(value: bool):
def allow_dynamic_shape_guard(value: bool):
with EnvironmentVariableGuard(ENV_SOT_ALLOW_DYNAMIC_SHAPE, value):
yield


@contextmanager
def faster_guard_guard(value: bool):
with EnvironmentVariableGuard(ENV_SOT_ENABLE_FASTER_GUARD, value):
yield


@contextmanager
def sot_step_profiler_guard(value: bool):
with EnvironmentVariableGuard(ENV_ENABLE_SOT_STEP_PROFILER, value):
Expand Down
20 changes: 10 additions & 10 deletions test/sot/test_sot_dynamic_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import paddle
from paddle.jit.sot.psdb import check_no_breakgraph
from paddle.jit.sot.utils import (
with_allow_dynamic_shape_guard,
allow_dynamic_shape_guard,
)


Expand Down Expand Up @@ -78,7 +78,7 @@ def forward(self, x):

class TestOpcodeExecutorDynamicShapeCache(TestCaseBase):
def test_dynamic_int_input_cache_hit_case1(self):
with with_allow_dynamic_shape_guard(
with allow_dynamic_shape_guard(
True
), test_instruction_translator_cache_context() as ctx:
self.assert_results(
Expand All @@ -92,7 +92,7 @@ def test_dynamic_int_input_cache_hit_case1(self):
self.assertEqual(ctx.translate_count, 2)

def test_dynamic_int_input_cache_hit_case2(self):
with with_allow_dynamic_shape_guard(
with allow_dynamic_shape_guard(
True
), test_instruction_translator_cache_context() as ctx:
self.assert_results(
Expand All @@ -106,7 +106,7 @@ def test_dynamic_int_input_cache_hit_case2(self):
self.assertEqual(ctx.translate_count, 2)

def test_dynamic_int_input_cache_hit_case3(self):
with with_allow_dynamic_shape_guard(
with allow_dynamic_shape_guard(
True
), test_instruction_translator_cache_context() as ctx:
for i in range(0, 6):
Expand All @@ -116,7 +116,7 @@ def test_dynamic_int_input_cache_hit_case3(self):
self.assertEqual(ctx.translate_count, i + 1)

def test_dynamic_shape_input_cache_hit_case1(self):
with with_allow_dynamic_shape_guard(
with allow_dynamic_shape_guard(
True
), test_instruction_translator_cache_context() as ctx:
self.assert_results(
Expand All @@ -130,7 +130,7 @@ def test_dynamic_shape_input_cache_hit_case1(self):
self.assertEqual(ctx.translate_count, 2)

def test_dynamic_shape_input_cache_hit_case2(self):
with with_allow_dynamic_shape_guard(
with allow_dynamic_shape_guard(
True
), test_instruction_translator_cache_context() as ctx:
self.assert_results(
Expand All @@ -145,7 +145,7 @@ def test_dynamic_shape_input_cache_hit_case2(self):
self.assertEqual(ctx.translate_count, 2)

def test_dynamic_shape_cast(self):
with with_allow_dynamic_shape_guard(
with allow_dynamic_shape_guard(
True
), test_instruction_translator_cache_context() as ctx:
func1 = check_no_breakgraph(lambda n: bool(n))
Expand All @@ -156,7 +156,7 @@ def test_dynamic_shape_cast(self):
self.assert_results(func, 2)

def test_dynamic_shape_in_list(self):
with with_allow_dynamic_shape_guard(
with allow_dynamic_shape_guard(
True
), test_instruction_translator_cache_context() as ctx:
self.assert_results(
Expand All @@ -174,7 +174,7 @@ def test_dynamic_shape_in_list(self):
self.assertEqual(ctx.translate_count, 2)

def test_conv_dynamic_shape_fallback(self):
with with_allow_dynamic_shape_guard(
with allow_dynamic_shape_guard(
True
), test_instruction_translator_cache_context() as ctx:
for i in range(1, 5):
Expand All @@ -183,7 +183,7 @@ def test_conv_dynamic_shape_fallback(self):
self.assertEqual(ctx.translate_count, i)

def test_pad_dynamic_shape_fallback(self):
with with_allow_dynamic_shape_guard(
with allow_dynamic_shape_guard(
True
), test_instruction_translator_cache_context() as ctx:
pad_func = check_no_breakgraph(
Expand Down
4 changes: 2 additions & 2 deletions test/sot/test_sot_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import unittest

import paddle
from paddle.jit.sot.utils import min_graph_size_guard, with_export_guard
from paddle.jit.sot.utils import export_guard, min_graph_size_guard


class Net(paddle.nn.Layer):
Expand All @@ -43,7 +43,7 @@ def test_basic(self):
temp_dir_name = temp_dir.name
net = Net()
x = paddle.to_tensor([2, 3], dtype="float32", stop_gradient=True)
with with_export_guard(temp_dir_name):
with export_guard(temp_dir_name):
y = paddle.jit.to_static(net)(x)
assert os.path.exists(os.path.join(temp_dir_name, "SIR_0.py"))
temp_dir.cleanup()
Expand Down
6 changes: 3 additions & 3 deletions test/sot/test_trace_list_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)

import paddle
from paddle.jit.sot.utils.envs import with_allow_dynamic_shape_guard
from paddle.jit.sot.utils.envs import allow_dynamic_shape_guard


def foo(x: list[paddle.Tensor], y: list[paddle.Tensor]):
Expand All @@ -47,7 +47,7 @@ def test_foo(self):
self.assert_results(foo, [a], [c]) # Cache miss
self.assertEqual(cache.translate_count, 2)

@with_allow_dynamic_shape_guard(False)
@allow_dynamic_shape_guard(False)
def test_bar_static_shape(self):
a = [paddle.to_tensor(1), paddle.to_tensor(2), paddle.to_tensor(3)]
b = [paddle.to_tensor([2, 3]), paddle.to_tensor(4), paddle.to_tensor(5)]
Expand All @@ -60,7 +60,7 @@ def test_bar_static_shape(self):
self.assert_results(bar, b, 1, 1) # Cache hit
self.assertEqual(cache.translate_count, 2)

@with_allow_dynamic_shape_guard(True)
@allow_dynamic_shape_guard(True)
def test_bar_dynamic_shape(self):
# TODO(zrr1999): mv to dynamic shape test
a = [paddle.to_tensor(1), paddle.to_tensor(2), paddle.to_tensor(3)]
Expand Down