
Commit 38df973

fix some bugs (apache#30)
2 parents eaccf13 + 7fea895 commit 38df973

File tree

11 files changed: +180 -21 lines changed

frontend/c_api.pyi

Lines changed: 4 additions & 0 deletions
@@ -112,3 +112,7 @@ def set_cell(cell: CellType, value: Any) -> None:
 
 def set_local(frame: FrameType, idx: int, value: Any) -> None:
     pass
+
+
+def parse_type_obj(obj: Any) -> str:
+    pass

frontend/csrc/csrc.h

Lines changed: 1 addition & 0 deletions
@@ -58,5 +58,6 @@ PyObject *parse_mapproxyobject(PyObject *self, PyObject *args);
 PyObject *parse_mapobject(PyObject *self, PyObject *args);
 PyObject *parse_cell(PyObject *self, PyObject *args);
 PyObject *set_cell(PyObject *self, PyObject *args);
+PyObject *parse_type_obj(PyObject *self, PyObject *args);
 
 } // namespace frontend_csrc

frontend/csrc/frame_evaluation.cpp

Lines changed: 1 addition & 0 deletions
@@ -696,6 +696,7 @@ static PyMethodDef _methods[] = {
     {"parse_mapobject", frontend_csrc::parse_mapobject, METH_VARARGS, NULL},
     {"parse_cell", frontend_csrc::parse_cell, METH_VARARGS, NULL},
     {"set_cell", frontend_csrc::set_cell, METH_VARARGS, NULL},
+    {"parse_type_obj", frontend_csrc::parse_type_obj, METH_VARARGS, NULL},
     {NULL, NULL, 0, NULL},
 };
 

frontend/csrc/parse_types.cpp

Lines changed: 11 additions & 0 deletions
@@ -110,4 +110,15 @@ PyObject *set_cell(PyObject *self, PyObject *args) {
     return Py_None;
 }
 
+PyObject *parse_type_obj(PyObject *self, PyObject *args) {
+    PyObject *obj;
+    if (!PyArg_ParseTuple(args, "O", &obj)) {
+        return NULL;
+    }
+    if (PyType_Check(obj)) {
+        return PyUnicode_FromString(((PyTypeObject *)obj)->tp_name);
+    }
+    PyErr_SetString(PyExc_TypeError, "Expected type object");
+    return NULL;
+}
 } // namespace frontend_csrc
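
Editor's note, not part of the commit: the binding returns the C-level tp_name of a type object, which for types defined in C extension modules includes the module prefix. An illustrative sketch:

import itertools
from frontend.c_api import parse_type_obj

parse_type_obj(int)                # "int": tp_name of a builtin type
parse_type_obj(itertools.product)  # "itertools.product": C types carry their module
parse_type_obj(3)                  # raises TypeError("Expected type object")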

frontend/guard_tracker.py

Lines changed: 55 additions & 19 deletions
@@ -145,10 +145,11 @@ def add_submodule(self, module: torch.nn.Module) -> None:
         self.update_subpath(module, new_module_name)
         # self.written = True # not mark as written as graph break may happen
 
-    def add_subparam(self, param: torch.nn.Parameter) -> None:
+    def add_subparam(self, param: torch.nn.Parameter) -> str:
         new_param_name = "external_param__" + str(len(self.subparam_paths))
         self.root.register_parameter(new_param_name, param)
         self.subparam_paths[param] = new_param_name
+        return new_param_name
 
     def as_node_args_kwargs(
             self, args: list[Any], kwargs: dict[str, Any]
@@ -172,6 +173,11 @@ def as_fx_node(arg: Any) -> NodeArgs:
             if isinstance(arg, slice):
                 return slice(as_fx_node(arg.start), as_fx_node(arg.stop),
                              as_fx_node(arg.step))
+            if isinstance(arg, np.ndarray):
+                param_name = self.add_subparam(
+                    torch.nn.Parameter(torch.tensor(arg), requires_grad=False))
+                return self.fx_graph.create_node("get_attr", param_name, (), {})
+
             var = self.objects.get(arg,
                                    allow_unexist_const=True,
                                    fx_graph=self.fx_graph)
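
The new np.ndarray branch folds numpy constants into the FX graph: the array becomes a frozen parameter on the root module, referenced through a get_attr node. A standalone sketch of the mechanism (names here are illustrative, not the tracker's):

import numpy as np
import torch
import torch.fx

root = torch.nn.Module()
arr = np.array([1.0, 2.0, 3.0])
# Freeze the array as a non-trainable parameter on the root module...
root.register_parameter(
    "external_param__0",
    torch.nn.Parameter(torch.tensor(arr), requires_grad=False))
# ...then reference it from the graph with a get_attr node.
graph = torch.fx.Graph()
node = graph.create_node("get_attr", "external_param__0", (), {})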
@@ -192,6 +198,9 @@ def as_fx_node(arg: Any) -> NodeArgs:
                 else:
                     # TODO: record all operation in SymInt or SymFloat
                     pass
+
+            if f"{type(arg).__module__}.{type(arg).__qualname__}" == "torch.tensortype":  # torch.LongTensor
+                return f"torch.{arg.__name__}"
             return var.as_fx_node()
 
         if isinstance(args, torch.Tensor):
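
The string comparison targets legacy tensor-type classes: torch.LongTensor and friends are instances of the torch.tensortype metaclass, which is exactly what the f-string matches, so they get encoded by name. Illustration (output values assumed from the check in the diff):

import torch

t = type(torch.LongTensor)
print(f"{t.__module__}.{t.__qualname__}")    # torch.tensortype
print(f"torch.{torch.LongTensor.__name__}")  # torch.LongTensor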
@@ -225,6 +234,19 @@ def record_function(self,
                         add_partial_var: bool = True,
                         inplace_ref: Any = None,
                         force_new_value: bool = False) -> None:
+        if hasattr(func, '__self__') and isinstance(
+                func.__self__, torch.autograd.grad_mode.no_grad):
+            if func.__name__ == '__enter__':
+                target_state = False
+            elif func.__name__ == '__exit__':
+                target_state = func.__self__.prev
+            else:
+                raise ValueError(func)
+            args = [
+                target_state,
+            ]
+            func = torch._C._set_grad_enabled
+            kwargs = {}
         pargs, pkwargs = self.as_node_args_kwargs(args, kwargs)
         if func in fx_graph_inplace_functions:
             scalar = None
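
This block rewrites no_grad.__enter__/__exit__ calls into torch._C._set_grad_enabled, a plain function that FX can record. The correspondence being relied on, roughly:

import torch

mgr = torch.no_grad()
mgr.__enter__()   # saves the old mode in mgr.prev, then disables grad
assert not torch.is_grad_enabled()
mgr.__exit__(None, None, None)   # restores mgr.prev, i.e. _set_grad_enabled(mgr.prev)
assert torch.is_grad_enabled()   # assuming grad was enabled beforehand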
@@ -268,6 +290,8 @@ def record_function(self,
             func = func_dict[func]
         if func in math2torch:
             func = math2torch[func]
+        if func == torch.from_numpy:
+            func = torch.tensor
 
         self.written = True
         scalar2tensor: dict[Callable[..., Any], Callable[..., Any]] = {
@@ -1360,7 +1384,6 @@ def make_sub_var(value: Any, fx_node: torch.fx.Node) -> None:
 
         self.state.inplace_update_objs.clear()
         self.state.partial_var.clear()
-        print("clear partial var")
         self.state.written = False
         self.state.unmark_calling_func()
         # print('process last instruction done')
@@ -1418,6 +1441,15 @@ def is_builtin_func(self, func: Callable[..., Any]) -> bool:
         return func in (dict, tuple, set, list, hasattr, slice, range, len,
                         type)
 
+    def is_numpy_constant_func(self, func: Callable[..., Any]) -> bool:
+        print(dir(func))
+        if (hasattr(func, '__module__') and 'numpy' in func.__module__ and
+                'random' not in func.__module__):
+            return True
+        if type(func) == np.ufunc:
+            return True
+        return False
+
     def get_live_objs(self, pc: int = -1) -> list[tuple[str, Any]]:
         if pc == -1:
             pc = self.frame.f_lasti // 2
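
The heuristic treats numpy calls as constant-foldable except for numpy.random, whose outputs vary between calls. Rough examples of each path (the module-name values are assumptions about the installed numpy, not taken from the commit):

import numpy as np

type(np.floor) == np.ufunc             # True: ufuncs are accepted
'numpy' in np.mean.__module__          # True, no 'random': accepted
'random' in np.random.rand.__module__  # True: rejected, not a constant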
@@ -1603,6 +1635,8 @@ def set_if_inplace_return() -> None:
             return
         elif len(args) > 0 and isinstance(args[0], torch.nn.ModuleList):
             return
+        elif self.is_numpy_constant_func(func):
+            return
         elif self.has_unknown_arg(args, kwargs):
             print(
                 f"func is {func}, {is_user_defined_func(func)}, args: {args}, kwargs:{kwargs}"
@@ -1789,7 +1823,9 @@ def SETUP_FINALLY(self, _inst: Instruction) -> None:
         pass
 
     def SETUP_WITH(self, _inst: Instruction) -> None:
-        pass
+        mgr = get_value_stack_from_top(self.frame, 0)
+        if type(mgr) == torch.autograd.grad_mode.no_grad:
+            self.call_function(mgr.__enter__, [], {})
 
     # def WITH_EXCEPT_START(self, _inst: Instruction) -> None:
     #     pass
@@ -1873,9 +1909,9 @@ def LOAD_ATTR(self, inst: Instruction) -> None:
         if inst.argval in obj_var.modified_attrs:
             return
         need_guard_check = obj_var.need_guard_check
-        if obj == self.state.varargs and inst.argval in dir(tuple):
+        if id(obj) == id(self.state.varargs) and inst.argval in dir(tuple):
             need_guard_check = False
-        if obj == self.state.varkw and inst.argval in dir(dict):
+        if id(obj) == id(self.state.varkw) and inst.argval in dir(dict):
             need_guard_check = False
         if config.get_config('dynshape') and isinstance(
                 obj, torch.Tensor) and inst.argval == 'shape':
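
Switching from == to id() fixes a real failure mode: varargs is a tuple that may contain tensors, and tuple equality compares elements with ==, which for a multi-element tensor raises rather than returning a bool. Identity is the intended check:

import torch

t = torch.rand(2)
obj = (t,)
other = (torch.rand(2),)
# Tuple equality calls bool(t == other[0]), which raises
# "Boolean value of Tensor with more than one element is ambiguous",
# so `obj == other` can blow up mid-trace.
print(id(obj) == id(other))  # False, and never touches tensor __eq__
print(id(obj) == id(obj))    # True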
@@ -1957,7 +1993,8 @@ def CALL_FUNCTION_KW(self, inst: Instruction) -> None:
                 '__self__') and func.__self__ is not None and not isinstance(
                     func.__self__, ModuleType):
             args = [func.__self__] + list(args)
-        # print(f"function kw: {func}, type: {type(func)},args:{args}, kwargs:{kwargs}")
+        for i, obj in enumerate(itertools.chain(args, kwargs.values())):
+            self.state.fetch_function_parameters(obj)
         self.call_function(func, args, kwargs)
 
     def CALL_FUNCTION_EX(self, inst: Instruction) -> None:
@@ -1973,6 +2010,9 @@ def CALL_FUNCTION_EX(self, inst: Instruction) -> None:
                 '__self__') and func.__self__ is not None and not isinstance(
                     func.__self__, ModuleType):
             args = [func.__self__] + list(args)
+        if not isinstance(args, torch.Tensor):  # call(*x)
+            for i, obj in enumerate(itertools.chain(args, kwargs.values())):
+                self.state.fetch_function_parameters(obj)
         self.call_function(func, args, kwargs)
 
     def STORE_FAST(self, inst: Instruction) -> None:
@@ -2076,19 +2116,15 @@ def IMPORT_FROM(self, inst: Instruction) -> None:
         pass
 
     def UNPACK_SEQUENCE(self, inst: Instruction) -> None:
-        seq = get_value_stack_from_top(self.frame, 0)
-        if isinstance(seq, (tuple, list)):
-            self.state.set_partial_var({
-                -1: [
-                    PartialVar(node=None,
-                               need_guard_check=False,
-                               extract_code_at_start=[],
-                               make_var_fn=vs.make_var_from_value)
-                    for _ in range(len(seq))
-                ]
-            })
-        else:
-            raise NotImplementedError
+        self.state.set_partial_var({
+            -1: [
+                PartialVar(node=None,
+                           need_guard_check=False,
+                           extract_code_at_start=[],
+                           make_var_fn=vs.make_var_from_value)
+                for _ in range(inst.argval)
+            ]
+        })
 
     def UNPACK_EX(self, inst: Instruction) -> None:
         seq = get_value_stack_from_top(self.frame, 0)
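
Reading the element count from inst.argval instead of calling len() on the value at the top of the stack lets UNPACK_SEQUENCE handle generators too, which was the NotImplementedError path before; the count is already encoded in the bytecode:

import dis

def f(g):
    a, b = g   # compiles to UNPACK_SEQUENCE 2; argval carries the count
    return a + b

dis.dis(f)                       # shows UNPACK_SEQUENCE 2
print(f(x + 1 for x in (1, 2)))  # 5: works although len(generator) is undefined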

frontend/utils.py

Lines changed: 10 additions & 1 deletion
@@ -12,6 +12,7 @@
 import torch._C
 import collections
 from .config import get_config, set_config
+from .c_api import parse_type_obj
 
 if TYPE_CHECKING:
     from .instruction import Instruction
@@ -202,6 +203,12 @@ def is_user_defined_func(func: Callable[..., Any]) -> bool:
         assert hasattr(func, '__self__')
         return is_user_defined_func(func.__self__)
 
+    if inspect.isclass(func):
+        tp_name = parse_type_obj(func)
+        module = tp_name.split(".")[0]
+        if module in ("itertools",):
+            return False
+
     if func is super:
         return False
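
This is where parse_type_obj pays off: itertools callables such as itertools.product are C-level classes, so inspect.isclass is true for them, and tp_name reveals their home module so they can be excluded from user-defined handling. A sketch of the logic:

import inspect
import itertools
from frontend.c_api import parse_type_obj

func = itertools.product
inspect.isclass(func)               # True: it is a C class, not a function
parse_type_obj(func)                # "itertools.product"
parse_type_obj(func).split(".")[0]  # "itertools" -> is_user_defined_func is False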

@@ -393,7 +400,7 @@ def enable_dyn_shape() -> Iterator[None]:
 
 
 def is_high_order_func(func: Callable[..., Any]) -> bool:
-    return func in high_order_func_list
+    return func in high_order_func_list or isinstance(func, Generator)
 
 
 def is_high_order_func_with_udf(func: Callable[..., Any], args: List[Any],
@@ -431,5 +438,7 @@ def call_user_defined_iterator(x: Any) -> bool:
         return len(args) >= 1 and is_user_defined_iter(args[0])
     elif func == enumerate:
         return len(args) >= 1 and is_user_defined_iter(args[0])
+    elif isinstance(func, Generator):
+        return True
     else:
         raise NotImplementedError
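
Both Generator branches match generator objects, not generator functions; resuming such an object runs user bytecode, hence the high-order treatment. A sketch of the distinction (using collections.abc.Generator; the Generator imported in utils.py is assumed compatible):

from collections.abc import Generator

def gen(xs):
    for x in xs:
        yield x + 1

isinstance(gen([1, 2]), Generator)  # True: the generator object matches
isinstance(gen, Generator)          # False: the function itself does not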

test/test_call_function_ex.py

Lines changed: 19 additions & 0 deletions
@@ -96,3 +96,22 @@ def test_call_ex_with_update(caplog):
     compiled = compile(outer_call_ex_with_update)
     run_and_check(compiled, [ALL_MISS], 1, caplog, expect, a, b)
     run_and_check(compiled, [HIT], 1, caplog, expect, a, b)
+
+
+def callee_kw(a, b):
+    return a[0] + b
+
+
+def caller_kw(a, b):
+    return callee_kw((a, 2), b=b)
+
+
+def test_caller_kw(caplog):
+    reset()
+    with torch.no_grad():
+        a = 1
+        b = 3
+        expect = caller_kw(a, b)
+        compiled = compile(caller_kw)
+        run_and_check(compiled, [ALL_MISS], 1, caplog, expect, a, b)
+        run_and_check(compiled, [HIT], 1, caplog, expect, a, b)

test/test_list.py

Lines changed: 18 additions & 1 deletion
@@ -1,5 +1,5 @@
 from frontend.compile import compile, reset
-from common.checker import run_and_check, HIT, MISS, assert_equal
+from common.checker import run_and_check, HIT, MISS, ALL_MISS, assert_equal
 import torch
 import numpy as np
 
@@ -204,3 +204,20 @@ def test_list_inplace(caplog):
     expect = list_inplace()
     run_and_check(compiled, [MISS], 1, caplog, expect)
     run_and_check(compiled, [HIT], 1, caplog, expect)
+
+
+# def unpack_list(a, b):
+#     a, b = (y + 1 for y in [a, b])
+#     return a + b
+
+# def test_unpack_list(caplog):
+#     reset()
+#     compiled = compile(unpack_list)
+#     expect = unpack_list(1, 2)
+#     run_and_check(compiled, [ALL_MISS], 1, caplog, expect, 1, 2)
+#     run_and_check(compiled, [HIT], 1, caplog, expect, 1, 2)
+#     a = torch.rand((2, 2))
+#     b = torch.rand((2, 2))
+#     expect = unpack_list(a, b)
+#     run_and_check(compiled, [ALL_MISS], 2, caplog, expect, a, b)
+#     run_and_check(compiled, [HIT], 2, caplog, expect, a, b)

test/test_numpy.py

Lines changed: 16 additions & 0 deletions
@@ -30,3 +30,19 @@ def test_numpy_to_int(caplog):
     result = numpy_to_int(10)
     run_and_check(compiled_numpy_to_int, [MISS], 1, caplog, result, 10)
     run_and_check(compiled_numpy_to_int, [HIT], 1, caplog, result, 10)
+
+
+def numpy_to_torch(x):
+    y = np.floor((x - 1) / 2)
+    return torch.tensor(y)
+
+
+def test_numpy_to_torch(caplog):
+    from frontend.utils import SetConfig
+    with SetConfig({"backend": "eager"}):
+        reset()
+        compiled = compile(numpy_to_torch)
+        a = np.array([1, 2.0, 3.33])
+        result = numpy_to_torch(a)
+        run_and_check(compiled, [MISS], 1, caplog, result, a)
+        run_and_check(compiled, [HIT], 1, caplog, result, a)

test/test_scalar.py

Lines changed: 15 additions & 0 deletions
@@ -225,3 +225,18 @@ def test_dynamic_scalar_from_tensor(caplog):
     bb = torch.tensor(5.0)
     expect = dynamic_scalar_from_tensor(aa, bb, c)
     run_and_check(compiled, [HIT], 1, caplog, expect, aa, bb, c)
+
+
+def itertools_product(a, b):
+    import itertools
+    return list(itertools.product(a, b))
+
+
+def test_itertools_product(caplog):
+    reset()
+    a = [1, 2]
+    b = [3, 4]
+    expect = itertools_product(a, b)
+    compiled = compile(itertools_product)
+    run_and_check(compiled, [MISS], 1, caplog, expect, a, b)
+    run_and_check(compiled, [HIT], 1, caplog, expect, a, b)
