Skip to content

Commit 84fedc3

Browse files
authored
Merge pull request #4366 from tybug/constants-speedup
Improve constants collection speed
2 parents 48edfbf + 1773eec commit 84fedc3

File tree

7 files changed

+180
-87
lines changed

7 files changed

+180
-87
lines changed

hypothesis-python/RELEASE.rst

+4
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,4 @@
1+
RELEASE_TYPE: patch
2+
3+
This patch makes the new features introduced in :ref:`version 6.131.1 <v6.131.1>` much
4+
faster, and fixes an internal ``RecursionError`` when working with deeply-nested code.

hypothesis-python/src/hypothesis/_settings.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -919,7 +919,7 @@ def note_deprecation(
919919
settings.register_profile("ci", CI)
920920

921921

922-
if is_in_ci():
922+
if is_in_ci(): # pragma: no cover # covered in ci, but not locally
923923
settings.load_profile("ci")
924924

925925
assert settings.default is not None

hypothesis-python/src/hypothesis/internal/conjecture/engine.py

+3-5
Original file line number | Diff line number | Diff line change
@@ -657,7 +657,7 @@ def on_pareto_evict(self, data: ConjectureResult) -> None:
657657
self.settings.database.delete(self.pareto_key, choices_to_bytes(data.choices))
658658

659659
def generate_novel_prefix(self) -> tuple[ChoiceT, ...]:
660-
"""Uses the tree to proactively generate a starting sequence of bytes
660+
"""Uses the tree to proactively generate a starting choice sequence
661661
that we haven't explored yet for this test.
662662
663663
When this method is called, we assume that there must be at
@@ -1047,10 +1047,8 @@ def generate_new_examples(self) -> None:
10471047
ran_optimisations = False
10481048

10491049
while self.should_generate_more():
1050-
# Unfortunately generate_novel_prefix still operates in terms of
1051-
# a buffer and uses HypothesisProvider as its backing provider,
1052-
# not whatever is specified by the backend. We can improve this
1053-
# once more things are on the ir.
1050+
# we don't yet integrate DataTree with backends. Instead of generating
1051+
# a novel prefix, ask the backend for an input.
10541052
if not self.using_hypothesis_backend:
10551053
data = self.new_conjecture_data([])
10561054
with suppress(BackendCannotProceed):

hypothesis-python/src/hypothesis/internal/conjecture/providers.py

+43-17
Original file line number | Diff line number | Diff line change
@@ -13,8 +13,10 @@
1313
import math
1414
import warnings
1515
from collections.abc import Iterable
16+
from functools import cached_property
1617
from random import Random
1718
from sys import float_info
19+
from types import ModuleType
1820
from typing import (
1921
TYPE_CHECKING,
2022
Any,
@@ -45,7 +47,7 @@
4547
Sampler,
4648
many,
4749
)
48-
from hypothesis.internal.constants_ast import local_constants
50+
from hypothesis.internal.constants_ast import constants_from_module, local_modules
4951
from hypothesis.internal.floats import (
5052
SIGNALING_NAN,
5153
float_to_int,
@@ -197,21 +199,40 @@
197199
}
198200

199201

200-
_local_constants_hash: Optional[int] = None
201-
202-
203-
def _get_local_constants():
204-
global _local_constants_hash
205-
206-
constants = local_constants()
207-
constants_hash = hash(tuple((k, tuple(v)) for k, v in constants.items()))
208-
# if we've added new constants since the last time we checked, invalidate
209-
# the cache.
210-
if constants_hash != _local_constants_hash:
211-
CONSTANTS_CACHE.cache.clear()
212-
_local_constants_hash = constants_hash
202+
_local_constants: "ConstantsT" = {
203+
"integer": SortedSet(),
204+
"float": SortedSet(key=float_to_int),
205+
"bytes": SortedSet(),
206+
"string": SortedSet(),
207+
}
208+
# modules that we've already seen and processed for local constants.
209+
_local_modules: set[ModuleType] = set()
210+
211+
212+
def _get_local_constants() -> "ConstantsT":
213+
new_constants: set[ConstantT] = set()
214+
new_modules = list(local_modules() - _local_modules)
215+
for new_module in new_modules:
216+
new_constants |= constants_from_module(new_module)
217+
218+
for constant in new_constants:
219+
choice_type = {
220+
int: "integer",
221+
float: "float",
222+
bytes: "bytes",
223+
str: "string",
224+
}[type(constant)]
225+
# if we add any new constant, invalidate the constant cache for permitted values.
226+
# A more efficient approach would be invalidating just the keys with this
227+
# choice_type.
228+
if (
229+
constant not in _local_constants[choice_type] # type: ignore
230+
): # pragma: no branch
231+
CONSTANTS_CACHE.cache.clear()
232+
_local_constants[choice_type].add(constant) # type: ignore
213233

214-
return constants
234+
_local_modules.update(new_modules)
235+
return _local_constants
215236

216237

217238
class _BackendInfoMsg(TypedDict):
@@ -371,9 +392,13 @@ class HypothesisProvider(PrimitiveProvider):
371392

372393
def __init__(self, conjecturedata: Optional["ConjectureData"], /):
373394
super().__init__(conjecturedata)
374-
self.local_constants = _get_local_constants()
375395
self._random = None if self._cd is None else self._cd._random
376396

397+
@cached_property
398+
def _local_constants(self):
399+
# defer computation of local constants until/if we need it
400+
return _get_local_constants()
401+
377402
def _maybe_draw_constant(
378403
self,
379404
choice_type: ChoiceTypeT,
@@ -382,6 +407,7 @@ def _maybe_draw_constant(
382407
p: float = 0.05,
383408
) -> Optional["ConstantT"]:
384409
assert self._random is not None
410+
assert self._local_constants is not None
385411
assert choice_type != "boolean"
386412

387413
# check whether we even want a constant before spending time computing
@@ -399,7 +425,7 @@ def _maybe_draw_constant(
399425
),
400426
tuple(
401427
choice
402-
for choice in self.local_constants[choice_type]
428+
for choice in self._local_constants[choice_type]
403429
if choice_permitted(choice, constraints)
404430
),
405431
)

hypothesis-python/src/hypothesis/internal/constants_ast.py

+52-48
Original file line number | Diff line number | Diff line change
@@ -9,19 +9,19 @@
99
# obtain one at https://mozilla.org/MPL/2.0/.
1010

1111
import ast
12+
import hashlib
1213
import inspect
1314
import math
1415
import sys
15-
from ast import AST, Constant, Expr, NodeVisitor, UnaryOp, USub
16+
from ast import Constant, Expr, NodeVisitor, UnaryOp, USub
1617
from functools import lru_cache
1718
from pathlib import Path
1819
from types import ModuleType
19-
from typing import TYPE_CHECKING, AbstractSet, Optional, TypedDict, Union
20-
21-
from sortedcontainers import SortedSet
20+
from typing import TYPE_CHECKING, AbstractSet, TypedDict, Union
2221

22+
import hypothesis
23+
from hypothesis.configuration import storage_directory
2324
from hypothesis.internal.escalation import is_hypothesis_file
24-
from hypothesis.internal.floats import float_to_int
2525

2626
if TYPE_CHECKING:
2727
from typing import TypeAlias
@@ -96,22 +96,55 @@ def visit_Constant(self, node):
9696
self.generic_visit(node)
9797

9898

99-
@lru_cache(1024)
100-
def constants_from_ast(tree: AST) -> AbstractSet[ConstantT]:
99+
def _constants_from_source(source: Union[str, bytes]) -> AbstractSet[ConstantT]:
100+
tree = ast.parse(source)
101101
visitor = ConstantVisitor()
102102
visitor.visit(tree)
103103
return visitor.constants
104104

105105

106-
@lru_cache(1024)
107-
def _module_ast(module: ModuleType) -> Optional[AST]:
106+
@lru_cache(4096)
107+
def constants_from_module(module: ModuleType) -> AbstractSet[ConstantT]:
108+
try:
109+
module_file = inspect.getsourcefile(module)
110+
# use type: ignore because module_file may be None here; any resulting error is caught by the except below
111+
source_bytes = Path(module_file).read_bytes() # type: ignore
112+
except Exception:
113+
return set()
114+
115+
source_hash = hashlib.sha1(source_bytes).hexdigest()[:16]
116+
cache_p = storage_directory("constants") / source_hash
117+
try:
118+
return _constants_from_source(cache_p.read_bytes())
119+
except Exception:
120+
# if the cached location doesn't exist, or it does exist but there was
121+
# a problem reading it, fall back to standard computation of the constants
122+
pass
123+
108124
try:
109-
source = inspect.getsource(module)
110-
tree = ast.parse(source)
125+
constants = _constants_from_source(source_bytes)
111126
except Exception:
112-
return None
127+
# A bunch of things can go wrong here.
128+
# * ast.parse may fail on the source code
129+
# * NodeVisitor may hit a RecursionError (see many related issues on
130+
# e.g. libcst https://github.com/Instagram/LibCST/issues?q=recursion),
131+
# or a MemoryError (`"[1, " * 200 + "]" * 200`)
132+
return set()
113133

114-
return tree
134+
try:
135+
cache_p.parent.mkdir(parents=True, exist_ok=True)
136+
cache_p.write_text(
137+
f"# file: {module_file}\n# hypothesis_version: {hypothesis.__version__}\n\n"
138+
# somewhat arbitrary sort order. The cache file doesn't *have* to be
139+
# stable... but it is aesthetically pleasing, and means we could rely
140+
# on it in the future!
141+
+ str(sorted(constants, key=lambda v: (str(type(v)), v))),
142+
encoding="utf-8",
143+
)
144+
except Exception: # pragma: no cover
145+
pass
146+
147+
return constants
115148

116149

117150
@lru_cache(4096)
@@ -141,7 +174,7 @@ def _is_local_module_file(path: str) -> bool:
141174
)
142175

143176

144-
def local_modules() -> tuple[ModuleType, ...]:
177+
def local_modules() -> set[ModuleType]:
145178
if sys.platform == "emscripten": # pragma: no cover
146179
# pyodide builds bundle the stdlib in a nonstandard location, like
147180
# `/lib/python312.zip/heapq.py`. To avoid identifying the entirety of
@@ -151,44 +184,15 @@ def local_modules() -> tuple[ModuleType, ...]:
151184
# pyodide may provide some way to distinguish stdlib/third-party/local
152185
# code. I haven't looked into it. If they do, we should correctly implement
153186
# ModuleLocation for pyodide instead of this.
154-
return ()
155-
156-
# Prevents a `RuntimeError` that can occur when looping over `sys.modules`
157-
# if it's simultaneously modified as a side effect of code in another thread.
158-
# See: https://docs.python.org/3/library/sys.html#sys.modules
159-
modules = sys.modules.copy().values()
187+
return set()
160188

161-
return tuple(
189+
return {
162190
module
163-
for module in modules
191+
# copy to avoid a RuntimeError if another thread imports a module while
192+
# we're iterating.
193+
for module in sys.modules.copy().values()
164194
if (
165195
getattr(module, "__file__", None) is not None
166196
and _is_local_module_file(module.__file__)
167197
)
168-
)
169-
170-
171-
def local_constants() -> ConstantsT:
172-
constants: set[ConstantT] = set()
173-
for module in local_modules():
174-
tree = _module_ast(module)
175-
if tree is None: # pragma: no cover
176-
continue
177-
constants |= constants_from_ast(tree)
178-
179-
local_constants: ConstantsT = {
180-
"integer": SortedSet(),
181-
"float": SortedSet(key=float_to_int),
182-
"bytes": SortedSet(),
183-
"string": SortedSet(),
184198
}
185-
for value in constants:
186-
choice_type = {
187-
int: "integer",
188-
float: "float",
189-
bytes: "bytes",
190-
str: "string",
191-
}[type(value)]
192-
local_constants[choice_type].add(value) # type: ignore # hard to type
193-
194-
return local_constants

hypothesis-python/tests/conjecture/test_local_constants.py

+28-5
Original file line number | Diff line number | Diff line change
@@ -12,9 +12,11 @@
1212

1313
import pytest
1414

15-
from hypothesis import settings, strategies as st
15+
from hypothesis import given, settings, strategies as st
16+
from hypothesis.internal import constants_ast
1617
from hypothesis.internal.conjecture import providers
1718
from hypothesis.internal.conjecture.choice import choice_equal
19+
from hypothesis.internal.conjecture.providers import CONSTANTS_CACHE
1820

1921
from tests.common.debug import find_any
2022
from tests.common.utils import Why, xfail_on_crosshair
@@ -24,30 +26,51 @@
2426
# with CONSTANTS_CACHE when testing it inside of a hypothesis test.
2527
@pytest.mark.parametrize("value", [2**20 - 50, 2**10 - 10, 129387123, -19827321, 0])
2628
def test_can_draw_local_constants_integers(monkeypatch, value):
27-
monkeypatch.setattr(providers, "local_constants", lambda: {"integer": {value}})
29+
# _get_local_constants normally invalidates this cache for us, but we're
30+
# monkeypatching it.
31+
CONSTANTS_CACHE.cache.clear()
32+
monkeypatch.setattr(providers, "_get_local_constants", lambda: {"integer": {value}})
2833
find_any(st.integers(), lambda v: choice_equal(v, value))
2934

3035

3136
@xfail_on_crosshair(Why.undiscovered) # I think float_to_int is difficult for crosshair
3237
@pytest.mark.parametrize("value", [1.2938, -1823.0239, 1e999, math.nan])
3338
def test_can_draw_local_constants_floats(monkeypatch, value):
34-
monkeypatch.setattr(providers, "local_constants", lambda: {"float": {value}})
39+
CONSTANTS_CACHE.cache.clear()
40+
monkeypatch.setattr(providers, "_get_local_constants", lambda: {"float": {value}})
3541
find_any(st.floats(), lambda v: choice_equal(v, value))
3642

3743

3844
@pytest.mark.parametrize("value", [b"abdefgh", b"a" * 50])
3945
def test_can_draw_local_constants_bytes(monkeypatch, value):
40-
monkeypatch.setattr(providers, "local_constants", lambda: {"bytes": {value}})
46+
CONSTANTS_CACHE.cache.clear()
47+
monkeypatch.setattr(providers, "_get_local_constants", lambda: {"bytes": {value}})
4148
find_any(st.binary(), lambda v: choice_equal(v, value))
4249

4350

4451
@pytest.mark.parametrize("value", ["abdefgh", "a" * 50])
4552
def test_can_draw_local_constants_string(monkeypatch, value):
46-
monkeypatch.setattr(providers, "local_constants", lambda: {"string": {value}})
53+
CONSTANTS_CACHE.cache.clear()
54+
monkeypatch.setattr(providers, "_get_local_constants", lambda: {"string": {value}})
4755
# we have a bunch of strings in GLOBAL_CONSTANTS, so it might take a while
4856
# to generate our local constant.
4957
find_any(
5058
st.text(),
5159
lambda v: choice_equal(v, value),
5260
settings=settings(max_examples=5_000),
5361
)
62+
63+
64+
def test_actual_collection(monkeypatch):
65+
# covering test for doing some real work collecting constants. We'll fake
66+
# hypothesis as being the "local" module, just to get some real constant
67+
# collection going.
68+
monkeypatch.setattr(
69+
constants_ast, "_is_local_module_file", lambda f: "hypothesis" in f
70+
)
71+
72+
@given(st.integers())
73+
def f(n):
74+
pass
75+
76+
f()

0 commit comments

Comments
 (0)