1
1
import hashlib
2
2
import inspect
3
3
import json
4
- from functools import lru_cache , wraps
4
+ from dataclasses import is_dataclass
5
+ from enum import Enum
6
+ from functools import lru_cache , partial , wraps
5
7
from typing import (
6
8
Any ,
7
9
Callable ,
@@ -357,8 +359,12 @@ def wrapper(action: E):
357
359
358
360
return func (cast (Any , event ))
359
361
360
- wrapper .__module__ = func .__module__
361
- wrapper .__name__ = func .__name__
362
+ if isinstance (func , partial ):
363
+ wrapper .__module__ = func .func .__module__
364
+ wrapper .__name__ = func .func .__name__
365
+ else :
366
+ wrapper .__module__ = func .__module__
367
+ wrapper .__name__ = func .__name__
362
368
363
369
return wrapper
364
370
@@ -374,14 +380,67 @@ def register_event_handler(
374
380
return fn_id
375
381
376
382
383
+ def has_stable_repr (obj : Any ) -> bool :
384
+ """Check if an object has a stable repr.
385
+ We need to ensure that the repr is stable between different Python runtimes.
386
+ """
387
+ stable_types = (int , float , str , bool , type (None ), tuple , frozenset , Enum ) # type: ignore
388
+
389
+ if isinstance (obj , stable_types ):
390
+ return True
391
+ if is_dataclass (obj ):
392
+ return all (
393
+ has_stable_repr (getattr (obj , f .name ))
394
+ for f in obj .__dataclass_fields__ .values ()
395
+ )
396
+ if isinstance (obj , (list , set )):
397
+ return all (has_stable_repr (item ) for item in obj ) # type: ignore
398
+ if isinstance (obj , dict ):
399
+ return all (
400
+ has_stable_repr (k ) and has_stable_repr (v )
401
+ for k , v in obj .items () # type: ignore
402
+ )
403
+
404
+ return False
405
+
406
+
377
407
@lru_cache (maxsize = None )
378
408
def compute_fn_id (fn : Callable [..., Any ]) -> str :
379
- source_code = inspect .getsource (fn )
409
+ if isinstance (fn , partial ):
410
+ func_source = inspect .getsource (fn .func )
411
+ # For partial functions, we need to ensure that the arguments have a stable repr
412
+ # because we use the repr to compute the fn_id.
413
+ for arg in fn .args :
414
+ if not has_stable_repr (arg ):
415
+ raise MesopDeveloperException (
416
+ f"Argument { arg } for functools.partial event handler { fn .func .__name__ } does not have a stable repr"
417
+ )
418
+
419
+ for k , v in fn .keywords .items ():
420
+ if not has_stable_repr (v ):
421
+ raise MesopDeveloperException (
422
+ f"Keyword argument { k } ={ v } for functools.partial event handler { fn .func .__name__ } does not have a stable repr"
423
+ )
424
+
425
+ args_str = ", " .join (repr (arg ) for arg in fn .args )
426
+ kwargs_str = ", " .join (f"{ k } ={ v !r} " for k , v in fn .keywords .items ())
427
+ partial_args = (
428
+ f"{ args_str } { ', ' if args_str and kwargs_str else '' } { kwargs_str } "
429
+ )
430
+
431
+ source_code = f"partial(<<{ func_source } >>, { partial_args } )"
432
+ fn_name = fn .func .__name__
433
+ fn_module = fn .func .__module__
434
+ else :
435
+ source_code = inspect .getsource (fn ) if inspect .isfunction (fn ) else str (fn )
436
+ fn_name = fn .__name__
437
+ fn_module = fn .__module__
438
+
380
439
# Skip hashing the fn/module name in debug mode because it makes it hard to debug.
381
440
if runtime ().debug_mode :
382
441
source_code_hash = hashlib .sha256 (source_code .encode ()).hexdigest ()
383
- return f"{ fn . __module__ } .{ fn . __name__ } .{ source_code_hash } "
384
- input = f"{ fn . __module__ } .{ fn . __name__ } .{ source_code } "
442
+ return f"{ fn_module } .{ fn_name } .{ source_code_hash } "
443
+ input = f"{ fn_module } .{ fn_name } .{ source_code } "
385
444
return hashlib .sha256 (input .encode ()).hexdigest ()
386
445
387
446
0 commit comments