|
7 | 7 | from test.support import check_free_after_iterating, ALWAYS_EQ, NEVER_EQ
|
8 | 8 | import pickle
|
9 | 9 | import collections.abc
|
| 10 | +import functools |
| 11 | +import contextlib |
| 12 | +import builtins |
10 | 13 |
|
11 | 14 | # Test result of triple loop (too big to inline)
|
12 | 15 | TRIPLETS = [(0, 0, 0), (0, 0, 1), (0, 0, 2),
|
@@ -81,6 +84,12 @@ class BadIterableClass:
|
81 | 84 | def __iter__(self):
|
82 | 85 | raise ZeroDivisionError
|
83 | 86 |
|
| 87 | +class EmptyIterClass: |
| 88 | + def __len__(self): |
| 89 | + return 0 |
| 90 | + def __getitem__(self, i): |
| 91 | + raise StopIteration |
| 92 | + |
84 | 93 | # Main test suite
|
85 | 94 |
|
86 | 95 | class TestCase(unittest.TestCase):
|
@@ -228,6 +237,78 @@ def test_mutating_seq_class_exhausted_iter(self):
|
228 | 237 | self.assertEqual(list(empit), [5, 6])
|
229 | 238 | self.assertEqual(list(a), [0, 1, 2, 3, 4, 5, 6])
|
230 | 239 |
|
| 240 | + def test_reduce_mutating_builtins_iter(self): |
| 241 | + # This is a reproducer of issue #101765 |
| 242 | + # where iter `__reduce__` calls could lead to a segfault or SystemError |
| 243 | + # depending on the order of C argument evaluation, which is undefined |
| 244 | + |
| 245 | + # Backup builtins |
| 246 | + builtins_dict = builtins.__dict__ |
| 247 | + orig = {"iter": iter, "reversed": reversed} |
| 248 | + |
| 249 | + def run(builtin_name, item, sentinel=None): |
| 250 | + it = iter(item) if sentinel is None else iter(item, sentinel) |
| 251 | + |
| 252 | + class CustomStr: |
| 253 | + def __init__(self, name, iterator): |
| 254 | + self.name = name |
| 255 | + self.iterator = iterator |
| 256 | + def __hash__(self): |
| 257 | + return hash(self.name) |
| 258 | + def __eq__(self, other): |
| 259 | + # Here we exhaust our iterator, possibly changing |
| 260 | + # its `it_seq` pointer to NULL |
| 261 | + # The `__reduce__` call should correctly get |
| 262 | + # the pointers after this call |
| 263 | + list(self.iterator) |
| 264 | + return other == self.name |
| 265 | + |
| 266 | + # del is required here |
| 267 | + # to not prematurely call __eq__ from |
| 268 | + # the hash collision with the old key |
| 269 | + del builtins_dict[builtin_name] |
| 270 | + builtins_dict[CustomStr(builtin_name, it)] = orig[builtin_name] |
| 271 | + |
| 272 | + return it.__reduce__() |
| 273 | + |
| 274 | + types = [ |
| 275 | + (EmptyIterClass(),), |
| 276 | + (bytes(8),), |
| 277 | + (bytearray(8),), |
| 278 | + ((1, 2, 3),), |
| 279 | + (lambda: 0, 0), |
| 280 | + (tuple[int],) # GenericAlias |
| 281 | + ] |
| 282 | + |
| 283 | + try: |
| 284 | + run_iter = functools.partial(run, "iter") |
| 285 | + # The returned value of `__reduce__` should not only be valid |
| 286 | + # but also *empty*, as `it` was exhausted during `__eq__` |
| 287 | + # i.e "xyz" returns (iter, ("",)) |
| 288 | + self.assertEqual(run_iter("xyz"), (orig["iter"], ("",))) |
| 289 | + self.assertEqual(run_iter([1, 2, 3]), (orig["iter"], ([],))) |
| 290 | + |
| 291 | + # _PyEval_GetBuiltin is also called for `reversed` in a branch of |
| 292 | + # listiter_reduce_general |
| 293 | + self.assertEqual( |
| 294 | + run("reversed", orig["reversed"](list(range(8)))), |
| 295 | + (iter, ([],)) |
| 296 | + ) |
| 297 | + |
| 298 | + for case in types: |
| 299 | + self.assertEqual(run_iter(*case), (orig["iter"], ((),))) |
| 300 | + finally: |
| 301 | + # Restore original builtins |
| 302 | + for key, func in orig.items(): |
| 303 | + # need to suppress KeyErrors in case |
| 304 | + # a failed test deletes the key without setting anything |
| 305 | + with contextlib.suppress(KeyError): |
| 306 | + # del is required here |
| 307 | + # to not invoke our custom __eq__ from |
| 308 | + # the hash collision with the old key |
| 309 | + del builtins_dict[key] |
| 310 | + builtins_dict[key] = func |
| 311 | + |
231 | 312 | # Test a new_style class with __iter__ but no next() method
|
232 | 313 | def test_new_style_iter_class(self):
|
233 | 314 | class IterClass(object):
|
|
0 commit comments