[SOT][Faster Guard] adapt to faster guard for more variables #69411

Merged (3 commits) on Nov 15, 2024
16 changes: 13 additions & 3 deletions paddle/fluid/pybind/jit.cc
@@ -87,12 +87,22 @@ void BindGuard(pybind11::module *m) {
.def(py::init<const paddle::framework::proto::VarType &>(),
py::arg("dtype"))
.def(py::init<const phi::DataType &>(), py::arg("dtype"));
py::class_<LayerMatchGuard, GuardBase, std::shared_ptr<LayerMatchGuard>>(
*m, "LayerMatchGuard", R"DOC(LayerMatchGuard Class.)DOC")
.def(py::init<const py::object &>(), py::arg("layer_obj"));
py::class_<AttributeMatchGuard,
GuardBase,
std::shared_ptr<AttributeMatchGuard>>(
*m, "AttributeMatchGuard", R"DOC(AttributeMatchGuard Class.)DOC")
.def(py::init<const py::object &, const std::string &>(),
py::arg("obj"),
py::arg("attr_name"));
py::class_<ShapeMatchGuard, GuardBase, std::shared_ptr<ShapeMatchGuard>>(
*m, "ShapeMatchGuard", R"DOC(ShapeMatchGuard Class.)DOC")
.def(py::init<const std::vector<py::object> &>(), py::arg("shape"));
py::class_<LayerMatchGuard, GuardBase, std::shared_ptr<LayerMatchGuard>>(
*m, "LayerMatchGuard", R"DOC(LayerMatchGuard Class.)DOC")
.def(py::init<const py::object &>(), py::arg("layer_obj"));
py::class_<RangeMatchGuard, GuardBase, std::shared_ptr<RangeMatchGuard>>(
*m, "RangeMatchGuard", R"DOC(RangeMatchGuard Class.)DOC")
.def(py::init<const py::object &>(), py::arg("range_obj"));

m->def(
"merge_guard",
35 changes: 22 additions & 13 deletions paddle/fluid/pybind/sot/guards.cc
@@ -26,6 +26,22 @@ static inline PyObject* PyObject_CallOneArg(PyObject* func, PyObject* arg) {
}
#endif

static inline bool PyObject_Equal(PyObject* a, PyObject* b) {
if (a == b) {
return true;
}
if (Py_TYPE(a) != Py_TYPE(b)) {
return false;
}
int result = PyObject_RichCompareBool(a, b, Py_EQ);
// Check for exception
if (result == -1) {
PyErr_Clear();
return false;
}
return result;
}

std::optional<paddle::Tensor> GetTensorFromPyObject(PyObject* obj) {
if (!paddle::pybind::PyCheckTensor(obj)) {
// TODO(zrr1999): PyCheckTensor only check if the object is a p_tensor_type.
@@ -59,19 +75,7 @@ bool TypeMatchGuard::check(PyObject* value) {
}

bool ValueMatchGuard::check(PyObject* value) {
if (value == expected_value_) {
return true;
}
if (Py_TYPE(value) != expected_type_) {
return false;
}
int result = PyObject_RichCompareBool(value, expected_value_, Py_EQ);
// Check for exception
if (result == -1) {
PyErr_Clear();
return false;
}
return result;
return PyObject_Equal(value, expected_value_);
}

bool LengthMatchGuard::check(PyObject* value) {
@@ -104,6 +108,11 @@ bool ShapeMatchGuard::check(PyObject* value) {
return true;
}

bool AttributeMatchGuard::check(PyObject* value) {
PyObject* attr = PyObject_GetAttrString(value, attr_name_.c_str());
return PyObject_Equal(attr, attr_ptr_);
}

bool LayerMatchGuard::check(PyObject* value) {
if (value != layer_ptr_) {
return false;
25 changes: 24 additions & 1 deletion paddle/fluid/pybind/sot/guards.h
@@ -72,9 +72,9 @@ class GuardGroup : public GuardBase {

class TypeMatchGuard : public GuardBase {
public:
explicit TypeMatchGuard(PyTypeObject* type_ptr) : expected_(type_ptr) {}
explicit TypeMatchGuard(PyObject* type_ptr)
: expected_(reinterpret_cast<PyTypeObject*>(type_ptr)) {}

explicit TypeMatchGuard(const py::type& py_type)
: expected_(reinterpret_cast<PyTypeObject*>(py_type.ptr())) {}

@@ -148,6 +148,19 @@ class ShapeMatchGuard : public GuardBase {
std::vector<std::optional<int64_t>> expected_;
};

class AttributeMatchGuard : public GuardBase {
Review comment (Member):

AttributeMatchGuard(v, "attr").check(x) will be somewhat faster than ValueMatchGuard(v).check(x.attr), right?

One consideration here: if the IR representation (the expression tree) contains value_guard_1(x.attr) and value_guard_2(x.attr), the common subexpression can be eliminated, whereas attr_guard_1(x, attr) and attr_guard_2(x, attr) is comparatively harder to optimize. This is worth keeping in mind in the follow-up design.

If AttributeMatchGuard really does bring an optimization, we could consider not materializing AttributeMatchGuard in the initially built Guard tree, so that optimizations such as common subexpression elimination can still take effect, and then fuse value_guard_1(x.attr) into attr_guard_1(x, attr) through a later pass.

This is a bit like composite operators: what the IR initially generates and what subsequent passes work on are basic Guards; graph transformation and fusion then produce high-performance composite Guards.

public:
AttributeMatchGuard(const py::object& obj, const std::string& attr_name)
: attr_ptr_(PyObject_GetAttrString(obj.ptr(), attr_name.c_str())),
attr_name_(attr_name) {}

bool check(PyObject* value);

private:
PyObject* attr_ptr_;
std::string attr_name_;
};

class LayerMatchGuard : public GuardBase {
public:
explicit LayerMatchGuard(PyObject* layer_ptr) : layer_ptr_(layer_ptr) {
@@ -164,4 +177,14 @@ class LayerMatchGuard : public GuardBase {
bool training_;
};

class RangeMatchGuard : public GuardGroup {
public:
explicit RangeMatchGuard(const py::object& range_obj)
: GuardGroup({std::make_shared<TypeMatchGuard>(Py_TYPE(range_obj.ptr())),
std::make_shared<AttributeMatchGuard>(range_obj, "start"),
std::make_shared<AttributeMatchGuard>(range_obj, "stop"),
std::make_shared<AttributeMatchGuard>(range_obj, "step")}) {
}
};

#endif
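
To ground the reviewer's comparison above, the sketch below contrasts the two formulations using only the bindings exercised in test_faster_guard.py (AttributeMatchGuard and ValueMatchGuard). It is a minimal illustration, assuming a Paddle build that includes this PR's bindings; the fusion pass the reviewer proposes is an idea for later work, not something implemented here.

import paddle

# Guard against the `start` attribute of a concrete range object.
r = range(1, 10, 2)

# AttributeMatchGuard reads the attribute inside the C++ check, so the
# attribute access does not go through a stringified Python expression.
attr_guard = paddle.framework.core.AttributeMatchGuard(r, "start")
print(attr_guard.check(range(1, 10, 2)))  # True
print(attr_guard.check(range(10)))        # False: start is 0, not 1

# The equivalent ValueMatchGuard formulation leaves the attribute access to
# the caller; this is the form the reviewer suggests the Guard tree could
# start from and later fuse into an AttributeMatchGuard via a pass.
value_guard = paddle.framework.core.ValueMatchGuard(r.start)
print(value_guard.check(range(1, 10, 2).start))  # True
print(value_guard.check(range(10).start))        # False

RangeMatchGuard in guards.h composes exactly these pieces: a TypeMatchGuard plus AttributeMatchGuards for start, stop, and step.
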
2 changes: 1 addition & 1 deletion python/paddle/jit/sot/opcode_translator/executor/guard.py
@@ -98,7 +98,7 @@ def __init__(
free_vars: dict[str, Any],
):
self.faster_guard = faster_guard
if ENV_SOT_ENABLE_FASTER_GUARD:
if ENV_SOT_ENABLE_FASTER_GUARD.get():
original_expr_template = expr_template
guard_cls_name = faster_guard.__class__.__name__
guard_name = f"{guard_cls_name}_{id(faster_guard)}"
26 changes: 15 additions & 11 deletions python/paddle/jit/sot/opcode_translator/executor/variables/base.py
@@ -23,10 +23,19 @@
import paddle

from ....profiler import event_register
from ....utils import NameGenerator, get_unbound_method, log
from ....utils import (
NameGenerator,
get_unbound_method,
log,
)
from ....utils.exceptions import FallbackError, HasNoAttributeError
from ..dispatcher import Dispatcher
from ..guard import StringifiedExpression, check_guard, union_free_vars
from ..guard import (
FasterStringifiedExpression,
StringifiedExpression,
check_guard,
union_free_vars,
)
from ..mutable_data import MutableDictLikeData
from ..tracker import (
DummyTracker,
@@ -364,18 +373,13 @@ def make_stringified_guard(self) -> list[StringifiedExpression]:

# Get a ValueTracer object from the Tracker object associated with the variable
frame_value_tracer = self.tracker.trace_value_from_frame()

return [
StringifiedExpression(
f"id(type({{}})) == {id(self.get_py_type())}",
FasterStringifiedExpression(
f"id(type({{0}})) == {id(self.get_py_type())} and {{0}} == {self.get_py_value()!r}",
paddle.framework.core.ValueMatchGuard(self.get_py_value()),
[frame_value_tracer],
union_free_vars(frame_value_tracer.free_vars),
),
StringifiedExpression(
f"{{}} == {self.get_py_value()!r}",
[frame_value_tracer],
union_free_vars(frame_value_tracer.free_vars),
),
)
]

def get_py_value(self, allow_tensor=False) -> Any:
Expand Up @@ -48,6 +48,7 @@
)
from ..dispatcher import Dispatcher
from ..guard import (
FasterStringifiedExpression,
StringifiedExpression,
check_guard,
object_equal_stringified_guard,
@@ -445,13 +446,9 @@ def call_function(self, /, *args, **kwargs):
def make_stringified_guard(self) -> list[StringifiedExpression]:
frame_value_tracer = self.tracker.trace_value_from_frame()
return [
StringifiedExpression(
f"id({{}}) == {id(self.get_py_value())}",
[frame_value_tracer],
union_free_vars(frame_value_tracer.free_vars),
),
StringifiedExpression(
f"{{}}.training == {self.get_py_value().training}",
FasterStringifiedExpression(
f"id({{0}}) == {id(self.get_py_value())} and {{0}}.training == {self.get_py_value().training}",
paddle.framework.core.ValueMatchGuard(self.get_py_value()),
[frame_value_tracer],
union_free_vars(frame_value_tracer.free_vars),
),
@@ -503,13 +500,14 @@ def make_stringified_guard(self) -> list[StringifiedExpression]:
if isinstance(self.value, PD_SEQ_CONTAINERS):
frame_value_tracer = self.tracker.trace_value_from_frame()

len_guard = StringifiedExpression(
len_guard = FasterStringifiedExpression(
f"len({{}}) == {len(self.value)}",
paddle.framework.core.LengthMatchGuard(len(self.value)),
[frame_value_tracer],
frame_value_tracer.free_vars,
)

guards = [len_guard]
guards: list[StringifiedExpression] = [len_guard]
for idx, layer in enumerate(self.value):
layer_variable = VariableFactory.from_value(
layer, self.graph, GetItemTracker(self, idx)
3 changes: 2 additions & 1 deletion test/sot/test_01_basic.py
@@ -14,7 +14,7 @@

import unittest

from test_case_base import TestCaseBase
from test_case_base import TestCaseBase, test_with_faster_guard

import paddle

@@ -24,6 +24,7 @@ def foo(x: int, y: paddle.Tensor):


class TestBasic(TestCaseBase):
@test_with_faster_guard
def test_simple(self):
self.assert_results(foo, 1, paddle.to_tensor(2))

2 changes: 2 additions & 0 deletions test/sot/test_12_for_loop.py
@@ -22,6 +22,7 @@
from test_case_base import (
TestCaseBase,
test_instruction_translator_cache_context,
test_with_faster_guard,
)

import paddle
@@ -224,6 +225,7 @@ def test_for_without_zero_iter(self):
def test_reconstruct_range_iter(self):
self.assert_results(for_reconstruct_range_iter)

@test_with_faster_guard
def test_layer_list(self):
layers = paddle.nn.LayerList()
for i in range(5):
5 changes: 4 additions & 1 deletion test/sot/test_17_paddle_layer.py
@@ -14,7 +14,7 @@

import unittest

from test_case_base import TestCaseBase
from test_case_base import TestCaseBase, test_with_faster_guard

import paddle

@@ -66,6 +66,7 @@ def forward(self, x):


class TestLayer(TestCaseBase):
@test_with_faster_guard
def test_layer(self):
x = paddle.rand((10,))
y = paddle.rand((10, 10))
@@ -74,6 +75,7 @@ def test_layer(self):
self.assert_results(net_call, y, net)
self.assert_results(net_call_passed_by_user, x, net.forward)

@test_with_faster_guard
def test_layer_with_sequential(self):
x = paddle.rand((10,))
y = paddle.rand((10, 10))
Expand All @@ -82,6 +84,7 @@ def test_layer_with_sequential(self):
self.assert_results(net_call, y, net)
self.assert_results(net_call_passed_by_user, x, net.forward)

@test_with_faster_guard
def test_bound(self):
x = paddle.rand((10,))
y = paddle.rand((10, 10))
15 changes: 13 additions & 2 deletions test/sot/test_faster_guard.py
@@ -19,7 +19,7 @@
import paddle


class TestFasterGuard(unittest.TestCase):
class TestBasicFasterGuard(unittest.TestCase):
def test_lambda_guard(self):
guard_lambda = paddle.framework.core.LambdaGuard(lambda x: x == 1)
self.assertTrue(guard_lambda.check(1))
@@ -69,6 +69,12 @@ def test_shape_match_guard(self):
guard_shape = paddle.framework.core.ShapeMatchGuard([2, 3, 1])
self.assertFalse(guard_shape.check(tensor))

def test_attribute_match_guard(self):
a = range(1, 10, 2)
guard_attribute = paddle.framework.core.AttributeMatchGuard(a, "start")
self.assertTrue(guard_attribute.check(a))
self.assertFalse(guard_attribute.check(range(10)))

def test_layer_match_guard(self):
layer = paddle.nn.Linear(10, 10)
guard_layer = paddle.framework.core.LayerMatchGuard(layer)
@@ -90,7 +96,7 @@ def test_guard_group(self):
self.assertTrue(guard_group.check(1))
self.assertFalse(guard_group.check(2))

def test_negated_guard_group(self):
def test_nested_guard_group(self):
guard_lambda = paddle.framework.core.LambdaGuard(lambda x: x == 1)
guard_type_match = paddle.framework.core.TypeMatchGuard(int)
guard_group = paddle.framework.core.GuardGroup(
@@ -103,6 +109,11 @@ def test_negated_guard_group(self):
self.assertTrue(guard_group.check(1))
self.assertFalse(guard_group.check(2))

def test_range_match_guard(self):
guard_range = paddle.framework.core.RangeMatchGuard(range(1, 10, 2))
self.assertTrue(guard_range.check(range(1, 10, 2)))
self.assertFalse(guard_range.check(range(11)))


if __name__ == "__main__":
unittest.main()