forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
_sources.py
137 lines (114 loc) · 4.32 KB
/
_sources.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import ast
import functools
import inspect
from textwrap import dedent
from typing import Any, List, NamedTuple, Optional, Tuple
from torch._C import ErrorReport
from torch._C._jit_tree_views import SourceRangeFactory
def get_source_lines_and_file(
    obj: Any,
    error_msg: Optional[str] = None,
) -> Tuple[List[str], int, Optional[str]]:
    """
    Fetch the source lines, starting line number and source file of ``obj``.

    Thin wrapper around ``inspect.getsourcefile`` and ``inspect.getsourcelines``
    that turns a missing-source ``OSError`` into a TorchScript-specific message.

    Args:
        obj: the object (typically a function) whose source is needed.
        error_msg: optional extra context appended on a new line to the
            error message when the source cannot be read.

    Returns:
        (sourcelines, file_lineno, filename). ``filename`` is whatever
        ``inspect.getsourcefile`` returned and may be ``None``.

    Raises:
        OSError: if ``inspect`` cannot locate the source; the original
            error is chained as ``__cause__``.
    """
    try:
        filename = inspect.getsourcefile(obj)
        sourcelines, file_lineno = inspect.getsourcelines(obj)
    except OSError as exc:
        # Assemble the message piecewise; the extra context (if any) goes on
        # its own line after the fixed explanation.
        pieces = [
            f"Can't get source for {obj}. TorchScript requires source access in "
            "order to carry out compilation, make sure original .py files are "
            "available."
        ]
        if error_msg:
            pieces.append(error_msg)
        raise OSError("\n".join(pieces)) from exc
    return sourcelines, file_lineno, filename
def normalize_source_lines(sourcelines: List[str]) -> List[str]:
    """
    Re-indent a function's source lines relative to its ``def`` line.

    Finds the first line whose first token is the ``def`` keyword, takes that
    line's leading whitespace, and prepends it to every other line (after
    removing one copy of it if already present). Lines that were at a lower
    indentation than the ``def`` (for example comments or continued string
    literals) therefore end up at or beyond the ``def``'s indentation, so a
    later ``textwrap.dedent`` of the joined source yields parseable code.

    Args:
        sourcelines: function source code split into lines (each normally
            ending in a newline character).

    Returns:
        A new list of aligned source lines. If no ``def`` line is found
        (e.g. a lambda), ``sourcelines`` itself is returned unchanged.
    """

    def remove_prefix(text: str, prefix: str) -> str:
        # Strip ``prefix`` only when present (str.removeprefix needs 3.9+).
        return text[len(prefix):] if text.startswith(prefix) else text

    def is_def_line(line: str) -> bool:
        # Match the ``def`` keyword itself, not identifiers that merely begin
        # with "def" -- e.g. ``default=1)`` on a wrapped decorator argument
        # line, which previously was mistaken for the function header.
        stripped = line.lstrip()
        if not stripped.startswith("def"):
            return False
        following = stripped[3:4]
        return following != "" and not (following.isalnum() or following == "_")

    # Index of the line containing the function definition.
    idx = next((i for i, line in enumerate(sourcelines) if is_def_line(line)), None)

    # This will happen when the function is a lambda- we won't find "def" anywhere in the source
    # lines in that case. Currently trying to JIT compile a lambda will throw an error up in
    # `parse_def()`, but we might want to handle this case in the future.
    if idx is None:
        return sourcelines

    fn_def = sourcelines[idx]
    # Leading whitespace of the `def` line (everything before the keyword).
    whitespace = fn_def[: len(fn_def) - len(fn_def.lstrip())]

    def realign(line: str) -> str:
        return whitespace + remove_prefix(line, whitespace)

    # Apply the `def` line's indentation to every line before and after it;
    # the `def` line itself is kept verbatim.
    aligned_prefix = [realign(s) for s in sourcelines[:idx]]
    aligned_suffix = [realign(s) for s in sourcelines[idx + 1 :]]
    return aligned_prefix + [fn_def] + aligned_suffix
class SourceContext(SourceRangeFactory):
    """
    Thin wrapper around ``SourceRangeFactory`` that also stores extra
    metadata about the function being compiled.

    Args:
        source: the source text handed to ``SourceRangeFactory``.
        filename: file the source came from (``None`` is accepted, as used
            by ``fake_range``); also stored as ``self.filename``.
        file_lineno: starting line number, forwarded to the base factory.
        leading_whitespace_len: indentation width, forwarded to the base factory.
        uses_true_division: flag stored on the instance (defaults to True);
            presumably records whether ``/`` means true division -- confirm
            with consumers.
        funcname: optional name of the function, stored as ``self.funcname``.
    """

    def __init__(
        self,
        source,
        filename,
        file_lineno,
        leading_whitespace_len,
        uses_true_division=True,
        funcname=None,
    ):
        # The base factory only receives the four positional location values.
        super().__init__(source, filename, file_lineno, leading_whitespace_len)
        # Python-side metadata carried alongside the factory.
        self.funcname = funcname
        self.filename = filename
        self.uses_true_division = uses_true_division
@functools.lru_cache(maxsize=None)
def make_source_context(*ctx_args):
    """
    Build a ``SourceContext`` from positional arguments, memoized.

    Identical (hashable) argument tuples return the same cached instance.
    The cache is unbounded (``maxsize=None``), so every distinct argument
    tuple -- including the full source string -- is retained for the
    lifetime of the process.
    """
    context = SourceContext(*ctx_args)
    return context
def fake_range():
    """
    Return a placeholder source range that is not tied to real code.

    The range comes from ``make_raw_range(0, 1)`` on a ``SourceContext``
    built with empty source text, no filename, and zero line number and
    indentation.
    """
    empty_context = SourceContext("", None, 0, 0)
    return empty_context.make_raw_range(0, 1)
class ParsedDef(NamedTuple):
    """Parsed function definition plus the source metadata it came from (as built by ``parse_def``)."""

    # Module returned by ``ast.parse`` on the dedented source; in parse_def its
    # body is checked to be exactly one ``ast.FunctionDef``.
    ast: ast.Module
    # SourceContext carrying filename, starting line and indentation width.
    ctx: SourceContext
    # Re-indented (normalized) source text, before dedenting.
    source: str
    # File reported by ``inspect.getsourcefile``; may be None.
    filename: Optional[str]
    # Starting line number reported by ``inspect.getsourcelines``.
    file_lineno: int
def parse_def(fn):
    """
    Read, normalize and parse the source of the Python function ``fn``.

    ``ErrorReport.call_stack()`` is passed as extra context for the error
    raised when the source is unavailable.

    Returns:
        ParsedDef(ast, ctx, source, filename, file_lineno), where ``ast`` is
        parsed from the dedented source and ``ctx`` is a (memoized)
        SourceContext recording the indentation that dedenting removed.

    Raises:
        OSError: if the source of ``fn`` cannot be read.
        RuntimeError: if the source is not exactly one top-level ``def``.
        SyntaxError: propagated from ``ast.parse`` for unparseable source.
    """
    raw_lines, lineno, fname = get_source_lines_and_file(
        fn, ErrorReport.call_stack()
    )
    source = "".join(normalize_source_lines(raw_lines))
    dedented = dedent(source)
    tree = ast.parse(dedented)

    body = tree.body
    if not (len(body) == 1 and isinstance(body[0], ast.FunctionDef)):
        raise RuntimeError(
            f"Expected a single top-level function: {fname}:{lineno}"
        )

    # Width of indentation removed by dedent, measured on the first line.
    first_raw = source.split("\n", 1)[0]
    first_dedented = dedented.split("\n", 1)[0]
    indent_len = len(first_raw) - len(first_dedented)

    # Arguments stay positional so the lru_cache key matches other callers.
    ctx = make_source_context(source, fname, lineno, indent_len, True, fn.__name__)
    return ParsedDef(tree, ctx, source, fname, lineno)