diff --git a/docs/api_core.rst b/docs/api_core.rst index 15b8e6492..58198f1d4 100644 --- a/docs/api_core.rst +++ b/docs/api_core.rst @@ -1616,7 +1616,9 @@ parameter of :cpp:func:`module_::def`, :cpp:func:`class_::def`, .. cpp:function:: template arg_v operator=(T &&value) const - Assign a default value to the argument. + Return an argument annotation that is like this one but also assigns a + default value to the argument. The default will be converted into a Python + object immediately, so its bindings must have already been defined. .. cpp:function:: arg &none(bool value = true) @@ -1638,11 +1640,11 @@ parameter of :cpp:func:`module_::def`, :cpp:func:`class_::def`, explain it in docstrings and stubs (``str(value)``) does not produce acceptable output. - .. cpp:function:: arg &lock(bool value = true) + .. cpp:function:: arg_locked lock() - Set a flag noting that this argument must be locked when dispatching a - function call in free-threaded Python extensions. It does nothing in - regular GIL-protected extensions. + Return an argument annotation that is like this one but also requests that + this argument be locked when dispatching a function call in free-threaded + Python extensions. It does nothing in regular GIL-protected extensions. .. cpp:struct:: is_method diff --git a/include/nanobind/nb_attr.h b/include/nanobind/nb_attr.h index 36a2dfb59..0476824ee 100644 --- a/include/nanobind/nb_attr.h +++ b/include/nanobind/nb_attr.h @@ -20,9 +20,17 @@ struct name { }; struct arg_v; +struct arg_locked; +struct arg_locked_v; + +// Basic function argument descriptor (no default value, not locked) struct arg { NB_INLINE constexpr explicit arg(const char *name = nullptr) : name_(name), signature_(nullptr) { } + + // operator= can be used to provide a default value template NB_INLINE arg_v operator=(T &&value) const; + + // Mutators that don't change default value or locked state NB_INLINE arg &noconvert(bool value = true) { convert_ = !value; return *this; @@ -31,29 +39,75 @@ struct arg { none_ = value; return *this; } - - NB_INLINE arg &lock(bool value = true) { - lock_ = value; - return *this; - } - NB_INLINE arg &sig(const char *value) { signature_ = value; return *this; } + // After lock(), this argument is locked + NB_INLINE arg_locked lock(); + const char *name_, *signature_; uint8_t convert_{ true }; bool none_{ false }; - bool lock_{ false }; }; +// Function argument descriptor with default value (not locked) struct arg_v : arg { object value; NB_INLINE arg_v(const arg &base, object &&value) : arg(base), value(std::move(value)) {} + + private: + // Inherited mutators would slice off the default, and are not generally needed + using arg::noconvert; + using arg::none; + using arg::sig; + using arg::lock; }; +// Function argument descriptor that is locked (no default value) +struct arg_locked : arg { + NB_INLINE constexpr explicit arg_locked(const char *name = nullptr) : arg(name) { } + NB_INLINE constexpr explicit arg_locked(const arg &base) : arg(base) { } + + // operator= can be used to provide a default value + template NB_INLINE arg_locked_v operator=(T &&value) const; + + // Mutators must be respecified in order to not slice off the locked status + NB_INLINE arg_locked &noconvert(bool value = true) { + convert_ = !value; + return *this; + } + NB_INLINE arg_locked &none(bool value = true) { + none_ = value; + return *this; + } + NB_INLINE arg_locked &sig(const char *value) { + signature_ = value; + return *this; + } + + // Redundant extra lock() is allowed + NB_INLINE arg_locked &lock() { return *this; } +}; + +// Function argument descriptor that is potentially locked and has a default value +struct arg_locked_v : arg_locked { + object value; + NB_INLINE arg_locked_v(const arg_locked &base, object &&value) + : arg_locked(base), value(std::move(value)) {} + + private: + // Inherited mutators would slice off the default, and are not generally needed + using arg_locked::noconvert; + using arg_locked::none; + using arg_locked::sig; + using arg_locked::lock; +}; + +NB_INLINE arg_locked arg::lock() { return arg_locked{*this}; } + template struct call_guard { using type = detail::tuple; }; @@ -133,9 +187,7 @@ enum class func_flags : uint32_t { /// Does this overload specify a custom function signature (for docstrings, typing) has_signature = (1 << 16), /// Does this function have one or more nb::keep_alive() annotations? - has_keep_alive = (1 << 17), - /// Free-threaded Python: does the binding lock the 'self' argument - lock_self = (1 << 18) + has_keep_alive = (1 << 17) }; enum cast_flags : uint8_t { @@ -148,14 +200,11 @@ enum cast_flags : uint8_t { // Indicates that the function dispatcher should accept 'None' arguments accepts_none = (1 << 2), - // Indicates that a function argument must be locked before dispatching a call - lock = (1 << 3), - // Indicates that this cast is performed by nb::cast or nb::try_cast. // This implies that objects added to the cleanup list may be // released immediately after the caster's final output value is // obtained, i.e., before it is used. - manual = (1 << 4) + manual = (1 << 3) }; @@ -300,8 +349,6 @@ NB_INLINE void func_extra_apply(F &f, const arg &a, size_t &index) { flag |= (uint8_t) cast_flags::accepts_none; if (a.convert_) flag |= (uint8_t) cast_flags::convert; - if (a.lock_) - flag |= (uint8_t) cast_flags::lock; arg_data &arg = f.args[index]; arg.flag = flag; @@ -310,6 +357,8 @@ NB_INLINE void func_extra_apply(F &f, const arg &a, size_t &index) { arg.value = nullptr; index++; } +// arg_locked will select the arg overload; the locking is added statically +// in nb_func.h template NB_INLINE void func_extra_apply(F &f, const arg_v &a, size_t &index) { @@ -317,14 +366,18 @@ NB_INLINE void func_extra_apply(F &f, const arg_v &a, size_t &index) { func_extra_apply(f, (const arg &) a, index); ad.value = a.value.ptr(); } +template +NB_INLINE void func_extra_apply(F &f, const arg_locked_v &a, size_t &index) { + arg_data &ad = f.args[index]; + func_extra_apply(f, (const arg_locked &) a, index); + ad.value = a.value.ptr(); +} template NB_INLINE void func_extra_apply(F &, kw_only, size_t &) {} template -NB_INLINE void func_extra_apply(F &f, lock_self, size_t &) { - f.flags |= (uint32_t) func_flags::lock_self; -} +NB_INLINE void func_extra_apply(F &, lock_self, size_t &) {} template NB_INLINE void func_extra_apply(F &, call_guard, size_t &) {} @@ -337,6 +390,7 @@ NB_INLINE void func_extra_apply(F &f, nanobind::keep_alive, size template struct func_extra_info { using call_guard = void; static constexpr bool keep_alive = false; + static constexpr size_t nargs_locked = 0; }; template struct func_extra_info @@ -354,6 +408,16 @@ struct func_extra_info, Ts...> : func_extra static constexpr bool keep_alive = true; }; +template +struct func_extra_info : func_extra_info { + static constexpr size_t nargs_locked = 1 + func_extra_info::nargs_locked; +}; + +template +struct func_extra_info : func_extra_info { + static constexpr size_t nargs_locked = 1 + func_extra_info::nargs_locked; +}; + template NB_INLINE void process_keep_alive(PyObject **, PyObject *, T *) { } diff --git a/include/nanobind/nb_call.h b/include/nanobind/nb_call.h index b20bba9d8..dfbeb45e8 100644 --- a/include/nanobind/nb_call.h +++ b/include/nanobind/nb_call.h @@ -35,6 +35,9 @@ args_proxy api::operator*() const { template NB_INLINE void call_analyze(size_t &nargs, size_t &nkwargs, const T &value) { using D = std::decay_t; + static_assert(!std::is_base_of_v, + "nb::arg().lock() may be used only when defining functions, " + "not when calling them"); if constexpr (std::is_same_v) nkwargs++; diff --git a/include/nanobind/nb_cast.h b/include/nanobind/nb_cast.h index 8c31ece2a..9df888d05 100644 --- a/include/nanobind/nb_cast.h +++ b/include/nanobind/nb_cast.h @@ -633,6 +633,9 @@ tuple make_tuple(Args &&...args) { template arg_v arg::operator=(T &&value) const { return arg_v(*this, cast((detail::forward_t) value)); } +template arg_locked_v arg_locked::operator=(T &&value) const { + return arg_locked_v(*this, cast((detail::forward_t) value)); +} template template detail::accessor& detail::accessor::operator=(T &&value) { diff --git a/include/nanobind/nb_func.h b/include/nanobind/nb_func.h index 226840502..1cd98a1b5 100644 --- a/include/nanobind/nb_func.h +++ b/include/nanobind/nb_func.h @@ -32,9 +32,40 @@ bool from_python_keep_alive(Caster &c, PyObject **args, uint8_t *args_flags, template constexpr size_t count_args_before_index(std::index_sequence) { static_assert(sizeof...(Is) == sizeof...(Ts)); - return ((Is < I && (std::is_same_v || std::is_same_v)) + ... + 0); + return ((Is < I && std::is_base_of_v) + ... + 0); } +#if defined(NB_FREE_THREADED) +struct ft_args_collector { + PyObject **args; + handle h1; + handle h2; + size_t index = 0; + + NB_INLINE explicit ft_args_collector(PyObject **a) : args(a) {} + NB_INLINE void apply(arg_locked *) { + if (h1.ptr() == nullptr) + h1 = args[index]; + h2 = args[index]; + ++index; + } + NB_INLINE void apply(arg *) { ++index; } + NB_INLINE void apply(...) {} +}; + +struct ft_args_guard { + NB_INLINE void lock(const ft_args_collector& info) { + PyCriticalSection2_Begin(&cs, info.h1.ptr(), info.h2.ptr()); + } + ~ft_args_guard() { + PyCriticalSection2_End(&cs); + } + PyCriticalSection2 cs; +}; +#endif + +struct no_guard {}; + template NB_INLINE PyObject *func_create(Func &&func, Return (*)(Args...), @@ -66,13 +97,21 @@ NB_INLINE PyObject *func_create(Func &&func, Return (*)(Args...), // Determine the number of nb::arg/nb::arg_v annotations constexpr size_t nargs_provided = - ((std::is_same_v + std::is_same_v) + ... + 0); + (std::is_base_of_v + ... + 0); constexpr bool is_method_det = (std::is_same_v + ... + 0) != 0; constexpr bool is_getter_det = (std::is_same_v + ... + 0) != 0; constexpr bool has_arg_annotations = nargs_provided > 0 && !is_getter_det; + // Determine the number of potentially-locked function arguments + constexpr bool lock_self_det = + (std::is_same_v + ... + 0) != 0; + static_assert(Info::nargs_locked <= 2, + "At most two function arguments can be locked"); + static_assert(!(lock_self_det && !is_method_det), + "The nb::lock_self() annotation only applies to methods"); + // Detect location of nb::kw_only annotation, if supplied. As with args/kwargs // we find the first and last location and later verify they match each other. // Note this is an index in Extra... while args/kwargs_pos_* are indices in @@ -187,6 +226,21 @@ NB_INLINE PyObject *func_create(Func &&func, Return (*)(Args...), tuple...> in; (void) in; +#if defined(NB_FREE_THREADED) + std::conditional_t guard; + if constexpr (Info::nargs_locked) { + ft_args_collector collector{args}; + if constexpr (is_method_det) { + if constexpr (lock_self_det) + collector.apply((arg_locked *) nullptr); + else + collector.apply((arg *) nullptr); + } + (collector.apply((Extra *) nullptr), ...); + guard.lock(collector); + } +#endif + if constexpr (Info::keep_alive) { if ((!from_python_keep_alive(in.template get(), args, args_flags, cleanup, Is) || ...)) diff --git a/src/nb_func.cpp b/src/nb_func.cpp index 682056a99..214a5169e 100644 --- a/src/nb_func.cpp +++ b/src/nb_func.cpp @@ -187,50 +187,6 @@ char *strdup_check(const char *s) { return result; } -#if defined(NB_FREE_THREADED) -// Locked function wrapper for free-threaded Python -struct locked_func { - void *capture[3]; - size_t nargs; - PyObject *(*impl)(void *, PyObject **, uint8_t *, rv_policy, - cleanup_list *); - void (*free_capture)(void *); -}; - -static void locked_func_free(void *p) { - locked_func *f = (locked_func *) ((void **) p)[0]; - if (f->free_capture) - f->free_capture(f->capture); - PyMem_Free(f); -} - -static PyObject *locked_func_impl(void *p, PyObject **args, uint8_t *args_flags, - rv_policy policy, cleanup_list *cleanup) { - handle h1, h2; - size_t ctr = 0; - - locked_func *f = (locked_func *) ((void **) p)[0]; - - for (size_t i = 0; i < f->nargs; ++i) { - if (args_flags[i] & (uint8_t) cast_flags::lock) { - if (ctr == 0) - h1 = args[i]; - h2 = args[i]; - ctr++; - } - } - -#if !defined(NDEBUG) - // nb_func_new ensured that at most two arguments are locked, but - // can't hurt to check once more in debug builds - check(ctr == 1 || ctr == 2, "locked_call: expected 1 or 2 locked arguments!"); -#endif - - ft_object2_guard guard(h1, h2); - return f->impl(f->capture, args, args_flags, policy, cleanup); -} -#endif - /** * \brief Wrap a C++ function into a Python function object * @@ -249,7 +205,6 @@ PyObject *nb_func_new(const void *in_) noexcept { is_implicit = f->flags & (uint32_t) func_flags::is_implicit, is_method = f->flags & (uint32_t) func_flags::is_method, return_ref = f->flags & (uint32_t) func_flags::return_ref, - lock_self = f->flags & (uint32_t) func_flags::lock_self, is_constructor = false, is_init = false, is_new = false, @@ -423,7 +378,6 @@ PyObject *nb_func_new(const void *in_) noexcept { } } - size_t lock_count = 0; if (has_args) { fc->args = (arg_data *) malloc_check(sizeof(arg_data) * f->nargs); @@ -442,55 +396,11 @@ PyObject *nb_func_new(const void *in_) noexcept { } if (a.value == Py_None) a.flag |= (uint8_t) cast_flags::accepts_none; - if (a.flag & (uint8_t) cast_flags::lock) - lock_count++; a.signature = a.signature ? strdup_check(a.signature) : nullptr; Py_XINCREF(a.value); } } - if (lock_self) { - check(is_method && !is_init, - "nb::detail::nb_func_new(\"%s\"): the nb::lock_self annotation only " - "applies to regular methods.", name_cstr); - -#if defined(NB_FREE_THREADED) - // Must potentially allocate dummy 'args' if 'lock_self' is set - if (!has_args) { - fc->args = (arg_data *) malloc_check(sizeof(arg_data) * f->nargs); - memset(fc->args, 0, sizeof(arg_data) * f->nargs); - for (uint32_t i = 1; i < f->nargs; ++i) - fc->args[i].flag &= (uint8_t) cast_flags::convert; - func->vectorcall = nb_func_vectorcall_complex; - fc->flags |= (uint32_t) func_flags::has_args; - has_args = true; - } - - fc->args[0].flag |= (uint8_t) cast_flags::lock; -#endif - - lock_count++; - } - - check(lock_count <= 2, - "nb::detail::nb_func_new(\"%s\"): at most two function arguments can " - "be locked.", name_cstr); - -#if defined(NB_FREE_THREADED) - if (lock_count) { - locked_func *lf = (locked_func *) PyMem_Malloc(sizeof(locked_func)); - check(lf, "nb::detail::nb_func_new(\"%s\"): locked function alloc. failed.", name_cstr); - memcpy(lf->capture, fc->capture, sizeof(func_data::capture)); - lf->nargs = fc->nargs; - lf->impl = fc->impl; - lf->free_capture = (fc->flags & (uint32_t) func_flags::has_free) ? fc->free_capture : nullptr; - fc->impl = locked_func_impl; - fc->free_capture = locked_func_free; - fc->flags |= (uint32_t) func_flags::has_free; - fc->capture[0] = lf; - } -#endif - // Fast path for vector call object construction if (((is_init && is_method) || (is_new && !is_method)) && nb_type_check(f->scope)) {