Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions pint/registry_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,13 @@ def _converter(ureg, values, strict):
return _converter


def _apply_defaults(func, args, kwargs):
def _apply_defaults(sig, args, kwargs):
"""Apply default keyword arguments.

Named keywords may have been left blank. This function applies the default
values so that every argument is defined.
"""

sig = signature(func)
bound_arguments = sig.bind(*args, **kwargs)
for param in sig.parameters.values():
if param.name not in bound_arguments.arguments:
Expand Down Expand Up @@ -254,7 +253,8 @@ def wraps(
ret = _to_units_container(ret, ureg)

def decorator(func: Callable[..., Any]) -> Callable[..., Quantity]:
count_params = len(signature(func).parameters)
sig = signature(func)
count_params = len(sig.parameters)
if len(args) != count_params:
raise TypeError(
"%s takes %i parameters, but %i units were passed"
Expand All @@ -270,7 +270,7 @@ def decorator(func: Callable[..., Any]) -> Callable[..., Quantity]:

@functools.wraps(func, assigned=assigned, updated=updated)
def wrapper(*values, **kw) -> Quantity:
values, kw = _apply_defaults(func, values, kw)
values, kw = _apply_defaults(sig, values, kw)

# In principle, the values are used as is
# When then extract the magnitudes when needed.
Expand Down Expand Up @@ -335,7 +335,8 @@ def check(
]

def decorator(func):
count_params = len(signature(func).parameters)
sig = signature(func)
count_params = len(sig.parameters)
if len(dimensions) != count_params:
raise TypeError(
"%s takes %i parameters, but %i dimensions were passed"
Expand All @@ -351,7 +352,7 @@ def decorator(func):

@functools.wraps(func, assigned=assigned, updated=updated)
def wrapper(*args, **kwargs):
list_args, empty = _apply_defaults(func, args, kwargs)
list_args, empty = _apply_defaults(sig, args, kwargs)

for dim, value in zip(dimensions, list_args):
if dim is None:
Expand Down
36 changes: 36 additions & 0 deletions pint/testsuite/benchmarks/test_20_quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,39 @@ def test_op2(benchmark, setup, keys, op):
_, data = setup
key1, key2 = keys
benchmark(op, data[key1], data[key2])


@pytest.mark.parametrize("key", ALL_VALUES_Q)
def test_wrapper(benchmark, setup, key):
ureg, data = setup
value, unit = key.split("_")

@ureg.wraps(None, (unit,))
def f(a):
pass

benchmark(f, data[key])


@pytest.mark.parametrize("key", ALL_VALUES_Q)
def test_wrapper_nonstrict(benchmark, setup, key):
ureg, data = setup
value, unit = key.split("_")

@ureg.wraps(None, (unit,), strict=False)
def f(a):
pass

benchmark(f, data[value])


@pytest.mark.parametrize("key", ALL_VALUES_Q)
def test_wrapper_ret(benchmark, setup, key):
ureg, data = setup
value, unit = key.split("_")

@ureg.wraps(unit, (unit,))
def f(a):
return a

benchmark(f, data[key])