Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions docs/api_core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1616,7 +1616,9 @@ parameter of :cpp:func:`module_::def`, :cpp:func:`class_::def`,

.. cpp:function:: template <typename T> arg_v operator=(T &&value) const

Assign a default value to the argument.
Return an argument annotation that is like this one but also assigns a
default value to the argument. The default will be converted into a Python
object immediately, so its bindings must have already been defined.

.. cpp:function:: arg &none(bool value = true)

Expand All @@ -1638,11 +1640,11 @@ parameter of :cpp:func:`module_::def`, :cpp:func:`class_::def`,
explain it in docstrings and stubs (``str(value)``) does not produce
acceptable output.

.. cpp:function:: arg &lock(bool value = true)
.. cpp:function:: arg_locked lock()

Set a flag noting that this argument must be locked when dispatching a
function call in free-threaded Python extensions. It does nothing in
regular GIL-protected extensions.
Return an argument annotation that is like this one but also requests that
this argument be locked when dispatching a function call in free-threaded
Python extensions. It does nothing in regular GIL-protected extensions.

.. cpp:struct:: is_method

Expand Down
102 changes: 83 additions & 19 deletions include/nanobind/nb_attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,17 @@ struct name {
};

struct arg_v;
struct arg_locked;
struct arg_locked_v;

// Basic function argument descriptor (no default value, not locked)
struct arg {
NB_INLINE constexpr explicit arg(const char *name = nullptr) : name_(name), signature_(nullptr) { }

// operator= can be used to provide a default value
template <typename T> NB_INLINE arg_v operator=(T &&value) const;

// Mutators that don't change default value or locked state
NB_INLINE arg &noconvert(bool value = true) {
convert_ = !value;
return *this;
Expand All @@ -31,29 +39,75 @@ struct arg {
none_ = value;
return *this;
}

NB_INLINE arg &lock(bool value = true) {
lock_ = value;
return *this;
}

NB_INLINE arg &sig(const char *value) {
signature_ = value;
return *this;
}

// After lock(), this argument is locked
NB_INLINE arg_locked lock();

const char *name_, *signature_;
uint8_t convert_{ true };
bool none_{ false };
bool lock_{ false };
};

// Function argument descriptor with default value (not locked)
struct arg_v : arg {
object value;
NB_INLINE arg_v(const arg &base, object &&value)
: arg(base), value(std::move(value)) {}

private:
// Inherited mutators would slice off the default, and are not generally needed
using arg::noconvert;
using arg::none;
using arg::sig;
using arg::lock;
};

// Function argument descriptor that is locked (no default value)
struct arg_locked : arg {
NB_INLINE constexpr explicit arg_locked(const char *name = nullptr) : arg(name) { }
NB_INLINE constexpr explicit arg_locked(const arg &base) : arg(base) { }

// operator= can be used to provide a default value
template <typename T> NB_INLINE arg_locked_v operator=(T &&value) const;

// Mutators must be respecified in order to not slice off the locked status
NB_INLINE arg_locked &noconvert(bool value = true) {
convert_ = !value;
return *this;
}
NB_INLINE arg_locked &none(bool value = true) {
none_ = value;
return *this;
}
NB_INLINE arg_locked &sig(const char *value) {
signature_ = value;
return *this;
}

// Redundant extra lock() is allowed
NB_INLINE arg_locked &lock() { return *this; }
};

// Function argument descriptor that is potentially locked and has a default value
struct arg_locked_v : arg_locked {
object value;
NB_INLINE arg_locked_v(const arg_locked &base, object &&value)
: arg_locked(base), value(std::move(value)) {}

private:
// Inherited mutators would slice off the default, and are not generally needed
using arg_locked::noconvert;
using arg_locked::none;
using arg_locked::sig;
using arg_locked::lock;
};

NB_INLINE arg_locked arg::lock() { return arg_locked{*this}; }

template <typename... Ts> struct call_guard {
using type = detail::tuple<Ts...>;
};
Expand Down Expand Up @@ -133,9 +187,7 @@ enum class func_flags : uint32_t {
/// Does this overload specify a custom function signature (for docstrings, typing)
has_signature = (1 << 16),
/// Does this function have one or more nb::keep_alive() annotations?
has_keep_alive = (1 << 17),
/// Free-threaded Python: does the binding lock the 'self' argument
lock_self = (1 << 18)
has_keep_alive = (1 << 17)
};

enum cast_flags : uint8_t {
Expand All @@ -148,14 +200,11 @@ enum cast_flags : uint8_t {
// Indicates that the function dispatcher should accept 'None' arguments
accepts_none = (1 << 2),

// Indicates that a function argument must be locked before dispatching a call
lock = (1 << 3),

// Indicates that this cast is performed by nb::cast or nb::try_cast.
// This implies that objects added to the cleanup list may be
// released immediately after the caster's final output value is
// obtained, i.e., before it is used.
manual = (1 << 4)
manual = (1 << 3)
};


Expand Down Expand Up @@ -300,8 +349,6 @@ NB_INLINE void func_extra_apply(F &f, const arg &a, size_t &index) {
flag |= (uint8_t) cast_flags::accepts_none;
if (a.convert_)
flag |= (uint8_t) cast_flags::convert;
if (a.lock_)
flag |= (uint8_t) cast_flags::lock;

arg_data &arg = f.args[index];
arg.flag = flag;
Expand All @@ -310,21 +357,27 @@ NB_INLINE void func_extra_apply(F &f, const arg &a, size_t &index) {
arg.value = nullptr;
index++;
}
// arg_locked will select the arg overload; the locking is added statically
// in nb_func.h

template <typename F>
NB_INLINE void func_extra_apply(F &f, const arg_v &a, size_t &index) {
arg_data &ad = f.args[index];
func_extra_apply(f, (const arg &) a, index);
ad.value = a.value.ptr();
}
template <typename F>
NB_INLINE void func_extra_apply(F &f, const arg_locked_v &a, size_t &index) {
arg_data &ad = f.args[index];
func_extra_apply(f, (const arg_locked &) a, index);
ad.value = a.value.ptr();
}

template <typename F>
NB_INLINE void func_extra_apply(F &, kw_only, size_t &) {}

template <typename F>
NB_INLINE void func_extra_apply(F &f, lock_self, size_t &) {
f.flags |= (uint32_t) func_flags::lock_self;
}
NB_INLINE void func_extra_apply(F &, lock_self, size_t &) {}

template <typename F, typename... Ts>
NB_INLINE void func_extra_apply(F &, call_guard<Ts...>, size_t &) {}
Expand All @@ -337,6 +390,7 @@ NB_INLINE void func_extra_apply(F &f, nanobind::keep_alive<Nurse, Patient>, size
template <typename... Ts> struct func_extra_info {
using call_guard = void;
static constexpr bool keep_alive = false;
static constexpr size_t nargs_locked = 0;
};

template <typename T, typename... Ts> struct func_extra_info<T, Ts...>
Expand All @@ -354,6 +408,16 @@ struct func_extra_info<nanobind::keep_alive<Nurse, Patient>, Ts...> : func_extra
static constexpr bool keep_alive = true;
};

template <typename... Ts>
struct func_extra_info<nanobind::arg_locked, Ts...> : func_extra_info<Ts...> {
static constexpr size_t nargs_locked = 1 + func_extra_info<Ts...>::nargs_locked;
};

template <typename... Ts>
struct func_extra_info<nanobind::lock_self, Ts...> : func_extra_info<Ts...> {
static constexpr size_t nargs_locked = 1 + func_extra_info<Ts...>::nargs_locked;
};

template <typename T>
NB_INLINE void process_keep_alive(PyObject **, PyObject *, T *) { }

Expand Down
3 changes: 3 additions & 0 deletions include/nanobind/nb_call.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ args_proxy api<Derived>::operator*() const {
template <typename T>
NB_INLINE void call_analyze(size_t &nargs, size_t &nkwargs, const T &value) {
using D = std::decay_t<T>;
static_assert(!std::is_base_of_v<arg_locked, D>,
"nb::arg().lock() may be used only when defining functions, "
"not when calling them");

if constexpr (std::is_same_v<D, arg_v>)
nkwargs++;
Expand Down
3 changes: 3 additions & 0 deletions include/nanobind/nb_cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,9 @@ tuple make_tuple(Args &&...args) {
template <typename T> arg_v arg::operator=(T &&value) const {
return arg_v(*this, cast((detail::forward_t<T>) value));
}
template <typename T> arg_locked_v arg_locked::operator=(T &&value) const {
return arg_locked_v(*this, cast((detail::forward_t<T>) value));
}

template <typename Impl> template <typename T>
detail::accessor<Impl>& detail::accessor<Impl>::operator=(T &&value) {
Expand Down
58 changes: 56 additions & 2 deletions include/nanobind/nb_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,40 @@ bool from_python_keep_alive(Caster &c, PyObject **args, uint8_t *args_flags,
template <size_t I, typename... Ts, size_t... Is>
constexpr size_t count_args_before_index(std::index_sequence<Is...>) {
static_assert(sizeof...(Is) == sizeof...(Ts));
return ((Is < I && (std::is_same_v<arg, Ts> || std::is_same_v<arg_v, Ts>)) + ... + 0);
return ((Is < I && std::is_base_of_v<arg, Ts>) + ... + 0);
}

#if defined(NB_FREE_THREADED)
struct ft_args_collector {
PyObject **args;
handle h1;
handle h2;
size_t index = 0;

NB_INLINE explicit ft_args_collector(PyObject **a) : args(a) {}
NB_INLINE void apply(arg_locked *) {
if (h1.ptr() == nullptr)
h1 = args[index];
h2 = args[index];
++index;
}
NB_INLINE void apply(arg *) { ++index; }
NB_INLINE void apply(...) {}
};

struct ft_args_guard {
NB_INLINE void lock(const ft_args_collector& info) {
PyCriticalSection2_Begin(&cs, info.h1.ptr(), info.h2.ptr());
}
~ft_args_guard() {
PyCriticalSection2_End(&cs);
}
PyCriticalSection2 cs;
};
#endif

struct no_guard {};

template <bool ReturnRef, bool CheckGuard, typename Func, typename Return,
typename... Args, size_t... Is, typename... Extra>
NB_INLINE PyObject *func_create(Func &&func, Return (*)(Args...),
Expand Down Expand Up @@ -66,13 +97,21 @@ NB_INLINE PyObject *func_create(Func &&func, Return (*)(Args...),

// Determine the number of nb::arg/nb::arg_v annotations
constexpr size_t nargs_provided =
((std::is_same_v<arg, Extra> + std::is_same_v<arg_v, Extra>) + ... + 0);
(std::is_base_of_v<arg, Extra> + ... + 0);
constexpr bool is_method_det =
(std::is_same_v<is_method, Extra> + ... + 0) != 0;
constexpr bool is_getter_det =
(std::is_same_v<is_getter, Extra> + ... + 0) != 0;
constexpr bool has_arg_annotations = nargs_provided > 0 && !is_getter_det;

// Determine the number of potentially-locked function arguments
constexpr bool lock_self_det =
(std::is_same_v<lock_self, Extra> + ... + 0) != 0;
static_assert(Info::nargs_locked <= 2,
"At most two function arguments can be locked");
static_assert(!(lock_self_det && !is_method_det),
"The nb::lock_self() annotation only applies to methods");

// Detect location of nb::kw_only annotation, if supplied. As with args/kwargs
// we find the first and last location and later verify they match each other.
// Note this is an index in Extra... while args/kwargs_pos_* are indices in
Expand Down Expand Up @@ -187,6 +226,21 @@ NB_INLINE PyObject *func_create(Func &&func, Return (*)(Args...),
tuple<make_caster<Args>...> in;
(void) in;

#if defined(NB_FREE_THREADED)
std::conditional_t<Info::nargs_locked != 0, ft_args_guard, no_guard> guard;
if constexpr (Info::nargs_locked) {
ft_args_collector collector{args};
if constexpr (is_method_det) {
if constexpr (lock_self_det)
collector.apply((arg_locked *) nullptr);
else
collector.apply((arg *) nullptr);
}
(collector.apply((Extra *) nullptr), ...);
guard.lock(collector);
}
#endif

if constexpr (Info::keep_alive) {
if ((!from_python_keep_alive(in.template get<Is>(), args,
args_flags, cleanup, Is) || ...))
Expand Down
Loading