|
1 | 1 | import hashlib
|
2 | 2 | import inspect
|
3 | 3 | import json
|
4 |
| -from functools import lru_cache, partial, wraps |
| 4 | +from dataclasses import is_dataclass |
| 5 | +from enum import Enum |
| 6 | +from functools import partial, wraps |
5 | 7 | from typing import (
|
6 | 8 | Any,
|
7 | 9 | Callable,
|
@@ -378,17 +380,51 @@ def register_event_handler(
|
378 | 380 | return fn_id
|
379 | 381 |
|
380 | 382 |
|
381 |
| -@lru_cache(maxsize=None) |
| 383 | +def has_stable_repr(obj: Any) -> bool: |
| 384 | + """Check if an object has a stable repr.""" |
| 385 | + stable_types = (int, float, str, bool, type(None), tuple, frozenset, Enum) # type: ignore |
| 386 | + |
| 387 | + if isinstance(obj, stable_types): |
| 388 | + return True |
| 389 | + if is_dataclass(obj): |
| 390 | + return all( |
| 391 | + has_stable_repr(getattr(obj, f.name)) |
| 392 | + for f in obj.__dataclass_fields__.values() |
| 393 | + ) |
| 394 | + if isinstance(obj, (list, set)): |
| 395 | + return all(has_stable_repr(item) for item in obj) # type: ignore |
| 396 | + if isinstance(obj, dict): |
| 397 | + return all( |
| 398 | + has_stable_repr(k) and has_stable_repr(v) |
| 399 | + for k, v in obj.items() # type: ignore |
| 400 | + ) |
| 401 | + |
| 402 | + return False |
| 403 | + |
| 404 | + |
382 | 405 | def compute_fn_id(fn: Callable[..., Any]) -> str:
|
383 | 406 | if isinstance(fn, partial):
|
384 |
| - # Include the partially applied arguments in the source code |
| 407 | + func_source = inspect.getsource(fn.func) |
| 408 | + |
| 409 | + for arg in fn.args: |
| 410 | + if not has_stable_repr(arg): |
| 411 | + raise MesopDeveloperException( |
| 412 | + f"Argument {arg} for functools.partial event handler {fn.func.__name__} does not have a stable repr" |
| 413 | + ) |
| 414 | + |
| 415 | + for k, v in fn.keywords.items(): |
| 416 | + if not has_stable_repr(v): |
| 417 | + raise MesopDeveloperException( |
| 418 | + f"Keyword argument {k}={v} for functools.partial event handler {fn.func.__name__} does not have a stable repr" |
| 419 | + ) |
| 420 | + |
385 | 421 | args_str = ", ".join(repr(arg) for arg in fn.args)
|
386 | 422 | kwargs_str = ", ".join(f"{k}={v!r}" for k, v in fn.keywords.items())
|
387 | 423 | partial_args = (
|
388 | 424 | f"{args_str}{', ' if args_str and kwargs_str else ''}{kwargs_str}"
|
389 | 425 | )
|
390 |
| - source_code = f"partial(<<{inspect.getsource(fn.func)}>>, {partial_args})" |
391 |
| - print("source_code", source_code) |
| 426 | + |
| 427 | + source_code = f"partial(<<{func_source}>>, {partial_args})" |
392 | 428 | fn_name = fn.func.__name__
|
393 | 429 | fn_module = fn.func.__module__
|
394 | 430 | else:
|
|
0 commit comments