Skip to content

Commit 2dfc770

Browse files
authored
Support functools.partial for event handlers (#789)
1 parent ef9d8ba commit 2dfc770

File tree

4 files changed

+105
-6
lines changed

4 files changed

+105
-6
lines changed

mesop/component_helpers/helper.py

+65-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import hashlib
22
import inspect
33
import json
4-
from functools import lru_cache, wraps
4+
from dataclasses import is_dataclass
5+
from enum import Enum
6+
from functools import lru_cache, partial, wraps
57
from typing import (
68
Any,
79
Callable,
@@ -357,8 +359,12 @@ def wrapper(action: E):
357359

358360
return func(cast(Any, event))
359361

360-
wrapper.__module__ = func.__module__
361-
wrapper.__name__ = func.__name__
362+
if isinstance(func, partial):
363+
wrapper.__module__ = func.func.__module__
364+
wrapper.__name__ = func.func.__name__
365+
else:
366+
wrapper.__module__ = func.__module__
367+
wrapper.__name__ = func.__name__
362368

363369
return wrapper
364370

@@ -374,14 +380,67 @@ def register_event_handler(
374380
return fn_id
375381

376382

383+
def has_stable_repr(obj: Any) -> bool:
384+
"""Check if an object has a stable repr.
385+
We need to ensure that the repr is stable between different Python runtimes.
386+
"""
387+
stable_types = (int, float, str, bool, type(None), tuple, frozenset, Enum) # type: ignore
388+
389+
if isinstance(obj, stable_types):
390+
return True
391+
if is_dataclass(obj):
392+
return all(
393+
has_stable_repr(getattr(obj, f.name))
394+
for f in obj.__dataclass_fields__.values()
395+
)
396+
if isinstance(obj, (list, set)):
397+
return all(has_stable_repr(item) for item in obj) # type: ignore
398+
if isinstance(obj, dict):
399+
return all(
400+
has_stable_repr(k) and has_stable_repr(v)
401+
for k, v in obj.items() # type: ignore
402+
)
403+
404+
return False
405+
406+
377407
@lru_cache(maxsize=None)
378408
def compute_fn_id(fn: Callable[..., Any]) -> str:
379-
source_code = inspect.getsource(fn)
409+
if isinstance(fn, partial):
410+
func_source = inspect.getsource(fn.func)
411+
# For partial functions, we need to ensure that the arguments have a stable repr
412+
# because we use the repr to compute the fn_id.
413+
for arg in fn.args:
414+
if not has_stable_repr(arg):
415+
raise MesopDeveloperException(
416+
f"Argument {arg} for functools.partial event handler {fn.func.__name__} does not have a stable repr"
417+
)
418+
419+
for k, v in fn.keywords.items():
420+
if not has_stable_repr(v):
421+
raise MesopDeveloperException(
422+
f"Keyword argument {k}={v} for functools.partial event handler {fn.func.__name__} does not have a stable repr"
423+
)
424+
425+
args_str = ", ".join(repr(arg) for arg in fn.args)
426+
kwargs_str = ", ".join(f"{k}={v!r}" for k, v in fn.keywords.items())
427+
partial_args = (
428+
f"{args_str}{', ' if args_str and kwargs_str else ''}{kwargs_str}"
429+
)
430+
431+
source_code = f"partial(<<{func_source}>>, {partial_args})"
432+
fn_name = fn.func.__name__
433+
fn_module = fn.func.__module__
434+
else:
435+
source_code = inspect.getsource(fn) if inspect.isfunction(fn) else str(fn)
436+
fn_name = fn.__name__
437+
fn_module = fn.__module__
438+
380439
# Skip hashing the fn/module name in debug mode because it makes it hard to debug.
381440
if runtime().debug_mode:
382441
source_code_hash = hashlib.sha256(source_code.encode()).hexdigest()
383-
return f"{fn.__module__}.{fn.__name__}.{source_code_hash}"
384-
input = f"{fn.__module__}.{fn.__name__}.{source_code}"
442+
return f"{fn_module}.{fn_name}.{source_code_hash}"
443+
input = f"{fn_module}.{fn_name}.{source_code}"
385444
return hashlib.sha256(input.encode()).hexdigest()
386445

387446

mesop/examples/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
)
2222
from mesop.examples import event_handler_error as event_handler_error
2323
from mesop.examples import focus_component as focus_component
24+
from mesop.examples import functools_partial as functools_partial
2425
from mesop.examples import generator as generator
2526
from mesop.examples import grid as grid
2627
from mesop.examples import index as index

mesop/examples/functools_partial.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from functools import partial
2+
3+
import mesop as me
4+
5+
6+
@me.stateclass
7+
class State:
8+
count: int = 0
9+
10+
11+
@me.page(path="/functools_partial")
12+
def main():
13+
state = me.state(State)
14+
me.text(text=f"value={state.count}")
15+
me.button("increment 2*4", on_click=partial(increment_click, 2, amount=4))
16+
me.button("increment 2*10", on_click=partial(increment_click, 2, amount=10))
17+
for i in range(10):
18+
me.button(f"decrement {i}", on_click=partial(decrement_click, i))
19+
20+
21+
def increment_click(multiplier: int, action: me.ClickEvent, amount: int):
22+
state = me.state(State)
23+
state.count += multiplier * amount
24+
25+
26+
def decrement_click(amount: int, action: me.ClickEvent):
27+
state = me.state(State)
28+
state.count -= amount
+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import {test, expect} from '@playwright/test';
2+
3+
test('functools partial', async ({page}) => {
4+
await page.goto('/functools_partial');
5+
6+
await page.getByRole('button', {name: 'increment 2*4'}).click();
7+
await expect(page.getByText('value=8')).toBeVisible();
8+
9+
await page.getByRole('button', {name: 'increment 2*10'}).click();
10+
await expect(page.getByText('value=28')).toBeVisible();
11+
});

0 commit comments

Comments
 (0)