Skip to content

Commit

Permalink
[SOT][Faster Guard] add ENV_SOT_ENABLE_FASTER_GUARD (#69263)
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 authored Nov 11, 2024
1 parent 4ee2f7a commit d466c1c
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 19 deletions.
6 changes: 4 additions & 2 deletions python/paddle/jit/sot/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,19 @@
ENV_COST_MODEL,
ENV_MIN_GRAPH_SIZE,
ENV_SOT_ALLOW_DYNAMIC_SHAPE,
ENV_SOT_ENABLE_FASTER_GUARD,
ENV_SOT_EXPORT,
ENV_SOT_LOG_LEVEL,
ENV_SOT_WITH_CONTROL_FLOW,
ENV_STRICT_MODE,
allow_dynamic_shape_guard,
cost_model_guard,
export_guard,
faster_guard_guard,
min_graph_size_guard,
sot_step_profiler_guard,
strict_mode_guard,
with_allow_dynamic_shape_guard,
with_control_flow_guard,
with_export_guard,
)
from .exceptions import ( # noqa: F401
BreakGraphError,
Expand Down
14 changes: 12 additions & 2 deletions python/paddle/jit/sot/utils/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
# Enable SOT dynamic shape as default in PIR mode only
paddle.framework.use_pir_api(),
)
ENV_SOT_ENABLE_FASTER_GUARD = BooleanEnvironmentVariable(
"SOT_ENABLE_FASTER_GUARD",
False,
)
ENV_SOT_EVENT_LEVEL = IntegerEnvironmentVariable("SOT_EVENT_LEVEL", 0)
ENV_ENABLE_SOT_STEP_PROFILER = BooleanEnvironmentVariable(
"ENABLE_SOT_STEP_PROFILER", False
Expand Down Expand Up @@ -69,17 +73,23 @@ def with_control_flow_guard(value: bool):


@contextmanager
def with_export_guard(value: str):
def export_guard(value: str):
with EnvironmentVariableGuard(ENV_SOT_EXPORT, value):
yield


@contextmanager
def with_allow_dynamic_shape_guard(value: bool):
def allow_dynamic_shape_guard(value: bool):
with EnvironmentVariableGuard(ENV_SOT_ALLOW_DYNAMIC_SHAPE, value):
yield


@contextmanager
def faster_guard_guard(value: bool):
with EnvironmentVariableGuard(ENV_SOT_ENABLE_FASTER_GUARD, value):
yield


@contextmanager
def sot_step_profiler_guard(value: bool):
with EnvironmentVariableGuard(ENV_ENABLE_SOT_STEP_PROFILER, value):
Expand Down
20 changes: 10 additions & 10 deletions test/sot/test_sot_dynamic_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import paddle
from paddle.jit.sot.psdb import check_no_breakgraph
from paddle.jit.sot.utils import (
with_allow_dynamic_shape_guard,
allow_dynamic_shape_guard,
)


Expand Down Expand Up @@ -78,7 +78,7 @@ def forward(self, x):

class TestOpcodeExecutorDynamicShapeCache(TestCaseBase):
def test_dynamic_int_input_cache_hit_case1(self):
with with_allow_dynamic_shape_guard(
with allow_dynamic_shape_guard(
True
), test_instruction_translator_cache_context() as ctx:
self.assert_results(
Expand All @@ -92,7 +92,7 @@ def test_dynamic_int_input_cache_hit_case1(self):
self.assertEqual(ctx.translate_count, 2)

def test_dynamic_int_input_cache_hit_case2(self):
with with_allow_dynamic_shape_guard(
with allow_dynamic_shape_guard(
True
), test_instruction_translator_cache_context() as ctx:
self.assert_results(
Expand All @@ -106,7 +106,7 @@ def test_dynamic_int_input_cache_hit_case2(self):
self.assertEqual(ctx.translate_count, 2)

def test_dynamic_int_input_cache_hit_case3(self):
with with_allow_dynamic_shape_guard(
with allow_dynamic_shape_guard(
True
), test_instruction_translator_cache_context() as ctx:
for i in range(0, 6):
Expand All @@ -116,7 +116,7 @@ def test_dynamic_int_input_cache_hit_case3(self):
self.assertEqual(ctx.translate_count, i + 1)

def test_dynamic_shape_input_cache_hit_case1(self):
with with_allow_dynamic_shape_guard(
with allow_dynamic_shape_guard(
True
), test_instruction_translator_cache_context() as ctx:
self.assert_results(
Expand All @@ -130,7 +130,7 @@ def test_dynamic_shape_input_cache_hit_case1(self):
self.assertEqual(ctx.translate_count, 2)

def test_dynamic_shape_input_cache_hit_case2(self):
with with_allow_dynamic_shape_guard(
with allow_dynamic_shape_guard(
True
), test_instruction_translator_cache_context() as ctx:
self.assert_results(
Expand All @@ -145,7 +145,7 @@ def test_dynamic_shape_input_cache_hit_case2(self):
self.assertEqual(ctx.translate_count, 2)

def test_dynamic_shape_cast(self):
with with_allow_dynamic_shape_guard(
with allow_dynamic_shape_guard(
True
), test_instruction_translator_cache_context() as ctx:
func1 = check_no_breakgraph(lambda n: bool(n))
Expand All @@ -156,7 +156,7 @@ def test_dynamic_shape_cast(self):
self.assert_results(func, 2)

def test_dynamic_shape_in_list(self):
with with_allow_dynamic_shape_guard(
with allow_dynamic_shape_guard(
True
), test_instruction_translator_cache_context() as ctx:
self.assert_results(
Expand All @@ -174,7 +174,7 @@ def test_dynamic_shape_in_list(self):
self.assertEqual(ctx.translate_count, 2)

def test_conv_dynamic_shape_fallback(self):
with with_allow_dynamic_shape_guard(
with allow_dynamic_shape_guard(
True
), test_instruction_translator_cache_context() as ctx:
for i in range(1, 5):
Expand All @@ -183,7 +183,7 @@ def test_conv_dynamic_shape_fallback(self):
self.assertEqual(ctx.translate_count, i)

def test_pad_dynamic_shape_fallback(self):
with with_allow_dynamic_shape_guard(
with allow_dynamic_shape_guard(
True
), test_instruction_translator_cache_context() as ctx:
pad_func = check_no_breakgraph(
Expand Down
4 changes: 2 additions & 2 deletions test/sot/test_sot_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import unittest

import paddle
from paddle.jit.sot.utils import min_graph_size_guard, with_export_guard
from paddle.jit.sot.utils import export_guard, min_graph_size_guard


class Net(paddle.nn.Layer):
Expand All @@ -43,7 +43,7 @@ def test_basic(self):
temp_dir_name = temp_dir.name
net = Net()
x = paddle.to_tensor([2, 3], dtype="float32", stop_gradient=True)
with with_export_guard(temp_dir_name):
with export_guard(temp_dir_name):
y = paddle.jit.to_static(net)(x)
assert os.path.exists(os.path.join(temp_dir_name, "SIR_0.py"))
temp_dir.cleanup()
Expand Down
6 changes: 3 additions & 3 deletions test/sot/test_trace_list_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)

import paddle
from paddle.jit.sot.utils.envs import with_allow_dynamic_shape_guard
from paddle.jit.sot.utils.envs import allow_dynamic_shape_guard


def foo(x: list[paddle.Tensor], y: list[paddle.Tensor]):
Expand All @@ -47,7 +47,7 @@ def test_foo(self):
self.assert_results(foo, [a], [c]) # Cache miss
self.assertEqual(cache.translate_count, 2)

@with_allow_dynamic_shape_guard(False)
@allow_dynamic_shape_guard(False)
def test_bar_static_shape(self):
a = [paddle.to_tensor(1), paddle.to_tensor(2), paddle.to_tensor(3)]
b = [paddle.to_tensor([2, 3]), paddle.to_tensor(4), paddle.to_tensor(5)]
Expand All @@ -60,7 +60,7 @@ def test_bar_static_shape(self):
self.assert_results(bar, b, 1, 1) # Cache hit
self.assertEqual(cache.translate_count, 2)

@with_allow_dynamic_shape_guard(True)
@allow_dynamic_shape_guard(True)
def test_bar_dynamic_shape(self):
# TODO(zrr1999): mv to dynamic shape test
a = [paddle.to_tensor(1), paddle.to_tensor(2), paddle.to_tensor(3)]
Expand Down

0 comments on commit d466c1c

Please sign in to comment.