9
9
# obtain one at https://mozilla.org/MPL/2.0/.
10
10
11
11
import ast
12
+ import hashlib
12
13
import inspect
13
14
import math
14
15
import sys
15
- from ast import AST , Constant , Expr , NodeVisitor , UnaryOp , USub
16
+ from ast import Constant , Expr , NodeVisitor , UnaryOp , USub
16
17
from functools import lru_cache
17
18
from pathlib import Path
18
19
from types import ModuleType
19
- from typing import TYPE_CHECKING , AbstractSet , Optional , TypedDict , Union
20
-
21
- from sortedcontainers import SortedSet
20
+ from typing import TYPE_CHECKING , AbstractSet , TypedDict , Union
22
21
22
+ import hypothesis
23
+ from hypothesis .configuration import storage_directory
23
24
from hypothesis .internal .escalation import is_hypothesis_file
24
- from hypothesis .internal .floats import float_to_int
25
25
26
26
if TYPE_CHECKING :
27
27
from typing import TypeAlias
@@ -96,22 +96,55 @@ def visit_Constant(self, node):
96
96
self .generic_visit (node )
97
97
98
98
99
- @ lru_cache ( 1024 )
100
- def constants_from_ast ( tree : AST ) -> AbstractSet [ ConstantT ]:
99
+ def _constants_from_source ( source : Union [ str , bytes ]) -> AbstractSet [ ConstantT ]:
100
+ tree = ast . parse ( source )
101
101
visitor = ConstantVisitor ()
102
102
visitor .visit (tree )
103
103
return visitor .constants
104
104
105
105
106
- @lru_cache (1024 )
107
- def _module_ast (module : ModuleType ) -> Optional [AST ]:
106
+ @lru_cache (4096 )
107
+ def constants_from_module (module : ModuleType ) -> AbstractSet [ConstantT ]:
108
+ try :
109
+ module_file = inspect .getsourcefile (module )
110
+ # use type: ignore because we know this might error
111
+ source_bytes = Path (module_file ).read_bytes () # type: ignore
112
+ except Exception :
113
+ return set ()
114
+
115
+ source_hash = hashlib .sha1 (source_bytes ).hexdigest ()[:16 ]
116
+ cache_p = storage_directory ("constants" ) / source_hash
117
+ try :
118
+ return _constants_from_source (cache_p .read_bytes ())
119
+ except Exception :
120
+ # if the cached location doesn't exist, or it does exist but there was
121
+ # a problem reading it, fall back to standard computation of the constants
122
+ pass
123
+
108
124
try :
109
- source = inspect .getsource (module )
110
- tree = ast .parse (source )
125
+ constants = _constants_from_source (source_bytes )
111
126
except Exception :
112
- return None
127
+ # A bunch of things can go wrong here.
128
+ # * ast.parse may fail on the source code
129
+ # * NodeVisitor may hit a RecursionError (see many related issues on
130
+ # e.g. libcst https://github.com/Instagram/LibCST/issues?q=recursion),
131
+ # or a MemoryError (`"[1, " * 200 + "]" * 200`)
132
+ return set ()
113
133
114
- return tree
134
+ try :
135
+ cache_p .parent .mkdir (parents = True , exist_ok = True )
136
+ cache_p .write_text (
137
+ f"# file: { module_file } \n # hypothesis_version: { hypothesis .__version__ } \n \n "
138
+ # somewhat arbitrary sort order. The cache file doesn't *have* to be
139
+ # stable... but it is aesthetically pleasing, and means we could rely
140
+ # on it in the future!
141
+ + str (sorted (constants , key = lambda v : (str (type (v )), v ))),
142
+ encoding = "utf-8" ,
143
+ )
144
+ except Exception : # pragma: no cover
145
+ pass
146
+
147
+ return constants
115
148
116
149
117
150
@lru_cache (4096 )
@@ -141,7 +174,7 @@ def _is_local_module_file(path: str) -> bool:
141
174
)
142
175
143
176
144
- def local_modules () -> tuple [ModuleType , ... ]:
177
+ def local_modules () -> set [ModuleType ]:
145
178
if sys .platform == "emscripten" : # pragma: no cover
146
179
# pyodide builds bundle the stdlib in a nonstandard location, like
147
180
# `/lib/python312.zip/heapq.py`. To avoid identifying the entirety of
@@ -151,44 +184,15 @@ def local_modules() -> tuple[ModuleType, ...]:
151
184
# pyodide may provide some way to distinguish stdlib/third-party/local
152
185
# code. I haven't looked into it. If they do, we should correctly implement
153
186
# ModuleLocation for pyodide instead of this.
154
- return ()
155
-
156
- # Prevents a `RuntimeError` that can occur when looping over `sys.modules`
157
- # if it's simultaneously modified as a side effect of code in another thread.
158
- # See: https://docs.python.org/3/library/sys.html#sys.modules
159
- modules = sys .modules .copy ().values ()
187
+ return set ()
160
188
161
- return tuple (
189
+ return {
162
190
module
163
- for module in modules
191
+ # copy to avoid a RuntimeError if another thread imports a module while
192
+ # we're iterating.
193
+ for module in sys .modules .copy ().values ()
164
194
if (
165
195
getattr (module , "__file__" , None ) is not None
166
196
and _is_local_module_file (module .__file__ )
167
197
)
168
- )
169
-
170
-
171
- def local_constants () -> ConstantsT :
172
- constants : set [ConstantT ] = set ()
173
- for module in local_modules ():
174
- tree = _module_ast (module )
175
- if tree is None : # pragma: no cover
176
- continue
177
- constants |= constants_from_ast (tree )
178
-
179
- local_constants : ConstantsT = {
180
- "integer" : SortedSet (),
181
- "float" : SortedSet (key = float_to_int ),
182
- "bytes" : SortedSet (),
183
- "string" : SortedSet (),
184
198
}
185
- for value in constants :
186
- choice_type = {
187
- int : "integer" ,
188
- float : "float" ,
189
- bytes : "bytes" ,
190
- str : "string" ,
191
- }[type (value )]
192
- local_constants [choice_type ].add (value ) # type: ignore # hard to type
193
-
194
- return local_constants
0 commit comments