Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 83 additions & 48 deletions gin/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def exit_scope(self):

def _find_class_construction_fn(cls):
"""Find the first __init__ or __new__ method in the given class's MRO."""
for base in type.mro(cls): # pytype: disable=wrong-arg-types
for base in cls.mro(): # pytype: disable=attribute-error
if '__init__' in base.__dict__:
return base.__init__
if '__new__' in base.__dict__:
Expand All @@ -216,54 +216,84 @@ def _ensure_wrappability(fn):
return fn


def _decorate_fn_or_cls(decorator, fn_or_cls, subclass=False):
def _decorate_fn_or_cls(decorator, fn_or_cls, avoid_class_mutation=False):
"""Decorate a function or class with the given decorator.

When `fn_or_cls` is a function, applies `decorator` to the function and
returns the (decorated) result.

When `fn_or_cls` is a class and the `subclass` parameter is `False`, this will
replace `fn_or_cls.__init__` with the result of applying `decorator` to it.

When `fn_or_cls` is a class and `subclass` is `True`, this will subclass the
class, but with `__init__` defined to be the result of applying `decorator` to
`fn_or_cls.__init__`. The decorated class has metadata (docstring, name, and
module information) copied over from `fn_or_cls`. The goal is to provide a
decorated class the behaves as much like the original as possible, without
modifying it (for example, inspection operations using `isinstance` or
`issubclass` should behave the same way as on the original class).
When `fn_or_cls` is a class and the `avoid_class_mutation` parameter is
`False`, this will replace either `fn_or_cls.__init__` or `fn_or_cls.__new__`
(whichever is first implement in the class's MRO, with a preference for
`__init__`) with the result of applying `decorator` to it.

When `fn_or_cls` is a class and `avoid_class_mutation` is `True`, this will
dynamically construct a subclass of `fn_or_cls` using a dynamically
constructed metaclass (which itself is a subclass of `fn_or_cls`'s metaclass).
The metaclass's `__call__` method is wrapped using `decorator` to
intercept/inject paramaters for class construction. The resulting subclass has
metadata (docstring, name, and module information) copied over from
`fn_or_cls`, and should behave like the original as much possible, without
modifying it (for example, inspection operations using `issubclass` should
behave the same way as on the original class). When constructed, an instance
of the original (undecorated) class is returned.

Args:
decorator: The decorator to use.
fn_or_cls: The function or class to decorate.
subclass: Whether to decorate classes by subclassing. This argument is
ignored if `fn_or_cls` is not a class.
avoid_class_mutation: Whether to avoid class mutation using dynamic
subclassing. This argument is ignored if `fn_or_cls` is not a class.

Returns:
The decorated function or class.
"""
if not inspect.isclass(fn_or_cls): # pytype: disable=wrong-arg-types
return decorator(_ensure_wrappability(fn_or_cls))

construction_fn = _find_class_construction_fn(fn_or_cls)

if subclass:

class DecoratedClass(fn_or_cls):
__doc__ = fn_or_cls.__doc__
__module__ = fn_or_cls.__module__

DecoratedClass.__name__ = fn_or_cls.__name__
DecoratedClass.__qualname__ = fn_or_cls.__qualname__
cls = DecoratedClass
cls = fn_or_cls
if avoid_class_mutation:
# This approach enables @gin.register and gin.external_configurable(), and
# is compatible with pickling instances. However, we can't use it for
# @gin.configurable because the decorated class returned below interacts
# poorly with super() calls when subclassed. In general, there isn't a
# strong use case for dynamically subclassing the Gin-wrapped classes
# resulting from @gin.register and gin.external_configurable() classes, so
# this tradeoff should be ok.
cls_meta = type(cls) # First, determine the metaclass of the given class.
# Now, we wrap the __call__ method on the metaclass. It is important to note
# that we ignore the first parameter (normally would be an instance of the
# metaclass we're creating, e.g. `decorated_cls` from below) and just pass
# `cls` directly to cls_meta's `__call__`. This ensures that we construct
# an actual instance of `cls`, and not our dynamically created subclass from
# below, which enables pickling instances of the class.
meta_call = lambda _, *a, **kw: cls_meta.__call__(cls, *a, **kw)
# We decorate our wrapped metaclass __call__ with Gin's wrapper.
decorated_call = decorator(_ensure_wrappability(meta_call))
# And now construct a new metaclass, subclassing the one from `cls`,
# supplying our decorated `__call__`. Most often this is just subclassing
# Python's `type`, but when `cls` has a custom metaclass set, this ensures
# that it will continue to work properly.
decorating_meta = type(cls_meta.__name__, (cls_meta,), {
'__call__': decorated_call,
})
# Now we construct our class. This is a subclass of `cls`, but with no
# overrides, since injecting/intercepting parameters is all handled in the
# metaclass's `__call__` method.
decorated_class = decorating_meta(cls.__name__, (cls,), {})
decorated_class.__name__ = cls.__name__
decorated_class.__doc__ = cls.__doc__
decorated_class.__qualname__ = cls.__qualname__
decorated_class.__module__ = cls.__module__
else:
cls = fn_or_cls

decorated_fn = decorator(_ensure_wrappability(construction_fn))
if construction_fn.__name__ == '__new__':
decorated_fn = staticmethod(decorated_fn)
setattr(cls, construction_fn.__name__, decorated_fn)
return cls
# Here, we just decorate `__init__` or `__new__` directly, and mutate the
# original class definition to use the decorated version. This is simpler
# and permits reliable subclassing of @gin.configurable decorated classes.
decorated_class = cls
construction_fn = _find_class_construction_fn(decorated_class)
decorated_fn = decorator(_ensure_wrappability(construction_fn))
if construction_fn.__name__ == '__new__':
decorated_fn = staticmethod(decorated_fn)
setattr(decorated_class, construction_fn.__name__, decorated_fn)
return decorated_class


class Configurable(
Expand Down Expand Up @@ -918,7 +948,8 @@ def _make_gin_wrapper(fn, fn_or_cls, name, selector, allowlist, denylist):
fn: The function that will be wrapped.
fn_or_cls: The original function or class being made configurable. This will
differ from `fn` when making a class configurable, in which case `fn` will
be the constructor/new function, while `fn_or_cls` will be the class.
be the constructor/new function (or when proxying a class, the type's
`__call__` method), while `fn_or_cls` will be the class.
name: The name given to the configurable.
selector: The full selector of the configurable (name including any module
components).
Expand All @@ -931,10 +962,13 @@ def _make_gin_wrapper(fn, fn_or_cls, name, selector, allowlist, denylist):
# At this point we have access to the final function to be wrapped, so we
# can cache a few things here.
fn_descriptor = "'{}' ('{}')".format(name, fn_or_cls)
signature_fn = fn_or_cls
if inspect.isclass(fn_or_cls):
signature_fn = _find_class_construction_fn(fn_or_cls)
signature_required_kwargs = _get_validated_required_kwargs(
fn, fn_descriptor, allowlist, denylist)
signature_fn, fn_descriptor, allowlist, denylist)
initial_configurable_defaults = _get_default_configurable_parameter_values(
fn, allowlist, denylist)
signature_fn, allowlist, denylist)

@functools.wraps(fn)
def gin_wrapper(*args, **kwargs):
Expand All @@ -947,7 +981,7 @@ def gin_wrapper(*args, **kwargs):
gin_bound_args = list(new_kwargs.keys())
scope_str = partial_scope_str

arg_names = _get_supplied_positional_parameter_names(fn, args)
arg_names = _get_supplied_positional_parameter_names(signature_fn, args)

for arg in args[len(arg_names):]:
if arg is REQUIRED:
Expand Down Expand Up @@ -1032,7 +1066,7 @@ def gin_wrapper(*args, **kwargs):

if missing_required_params:
missing_required_params = (
_order_by_signature(fn, missing_required_params))
_order_by_signature(signature_fn, missing_required_params))
err_str = 'Required bindings for `{}` not provided in config: {}'
minimal_selector = _REGISTRY.minimal_selector(selector)
err_str = err_str.format(minimal_selector, missing_required_params)
Expand All @@ -1047,7 +1081,7 @@ def gin_wrapper(*args, **kwargs):
except Exception as e: # pylint: disable=broad-except
err_str = ''
if isinstance(e, TypeError):
all_arg_names = _get_all_positional_parameter_names(fn)
all_arg_names = _get_all_positional_parameter_names(signature_fn)
if len(new_args) < len(all_arg_names):
unbound_positional_args = list(
set(all_arg_names[len(new_args):]) - set(new_kwargs))
Expand Down Expand Up @@ -1076,7 +1110,7 @@ def _make_configurable(fn_or_cls,
module=None,
allowlist=None,
denylist=None,
subclass=False):
avoid_class_mutation=False):
"""Wraps `fn_or_cls` to make it configurable.

Infers the configurable name from `fn_or_cls.__name__` if necessary, and
Expand All @@ -1094,9 +1128,10 @@ def _make_configurable(fn_or_cls,
is specified as part of `name`).
allowlist: An allowlisted set of parameter names to supply values for.
denylist: A denylisted set of parameter names not to supply values for.
subclass: If `fn_or_cls` is a class and `subclass` is `True`, decorate by
subclassing `fn_or_cls` and overriding its `__init__` method. If `False`,
replace the existing `__init__` with a decorated version.
avoid_class_mutation: If `fn_or_cls` is a class and `avoid_class_mutation`
is `True`, decorate by subclassing `fn_or_cls`'s metaclass and overriding
its `__call__` method. If `False`, replace the existing `__init__` or
`__new__` with a decorated version.

Returns:
A wrapped version of `fn_or_cls` that will take parameter values from the
Expand Down Expand Up @@ -1148,7 +1183,7 @@ def decorator(fn):
denylist)

decorated_fn_or_cls = _decorate_fn_or_cls(
decorator, fn_or_cls, subclass=subclass)
decorator, fn_or_cls, avoid_class_mutation=avoid_class_mutation)

_REGISTRY[selector] = Configurable(
decorated_fn_or_cls,
Expand Down Expand Up @@ -1310,7 +1345,7 @@ def external_configurable(fn_or_cls,
module=module,
allowlist=allowlist,
denylist=denylist,
subclass=True)
avoid_class_mutation=True)


def register(name_or_fn=None,
Expand Down Expand Up @@ -1402,14 +1437,14 @@ def __init__(self, param1, param2='a default value'):
name = name_or_fn

def perform_decoration(fn_or_cls):
# Register it as configurable but return the orinal fn_or_cls.
# Register it as configurable but return the original fn_or_cls.
_make_configurable(
fn_or_cls,
name=name,
module=module,
allowlist=allowlist,
denylist=denylist,
subclass=True)
avoid_class_mutation=True)
return fn_or_cls

if decoration_target:
Expand All @@ -1423,7 +1458,7 @@ def _config_str(configuration_object,
"""Print the configuration specified in configuration object.

Args:
configuration_object: Either OPERATIVE_CONFIG_ (operative config) or _CONFIG
configuration_object: Either _OPERATIVE_CONFIG (operative config) or _CONFIG
(all config, bound and unbound).
max_line_length: A (soft) constraint on the maximum length of a line in the
formatted string. Large nested structures will be split across lines, but
Expand Down
43 changes: 10 additions & 33 deletions tests/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,20 +310,6 @@ def __init__(self, kwarg1=None, kwarg2=None):
config.external_configurable(ExternalClass, 'module.ExternalConfigurable2')


@config.configurable
class ConfigurableExternalSubclass(configurable_external_class):
"""Subclassing an external configurable object.

This is a configurable subclass (of the configurable subclass implicitly
created by external_configurable) of the ExternalClass class.
"""

def __init__(self, kwarg1=None, kwarg2=None, kwarg3=None):
super(ConfigurableExternalSubclass, self).__init__(
kwarg1=kwarg1, kwarg2=kwarg2)
self.kwarg3 = kwarg3


class AbstractConfigurable(metaclass=abc.ABCMeta):

def __init__(self, kwarg1=None):
Expand Down Expand Up @@ -696,7 +682,9 @@ def testConfigurableSubclass(self):
# subclasses of the original class are not subclasses of the reference.
self.assertFalse(issubclass(sub_cls_ref, super_cls_ref))
self.assertNotIsInstance(sub_instance, super_cls_ref)
self.assertNotIsInstance(sub_instance, type(super_instance))
# But due to the fact that Gin's dynamic metaclass creates instances of the
# actual class (not the Gin subclass), instance checks work.
self.assertIsInstance(sub_instance, type(super_instance))

self.assertEqual(super_instance.kwarg1, 'one')
self.assertIsNone(super_instance.kwarg2)
Expand Down Expand Up @@ -770,33 +758,22 @@ def testImplicitlyScopedConfigurableClass(self):
self.assertEqual(scope2_instance.kwarg1, 'scope2arg1')
self.assertEqual(scope2_instance.kwarg2, 'scope2arg2')

def testImplicitlyScopedExternalConfigurableAndSubclass(self):
def testImplicitlyScopedExternalConfigurable(self):
config_str = """
configurable2.non_kwarg = @scope1/ExternalConfigurable
configurable2.kwarg1 = @scope2/ConfigurableExternalSubclass
scope1/ExternalConfigurable.kwarg1 = 'one'
scope2/ConfigurableExternalSubclass.kwarg2 = 'two'
scope2/ConfigurableExternalSubclass.kwarg3 = 'three'
"""
config.parse_config(config_str)
# pylint: disable=no-value-for-parameter
super_cls, sub_cls = configurable2()
cls, _ = configurable2()
# pylint: enable=no-value-for-parameter
self.assertTrue(issubclass(super_cls, ExternalClass))
self.assertTrue(issubclass(sub_cls, ExternalClass))
self.assertTrue(issubclass(sub_cls, ConfigurableExternalSubclass))

super_instance, sub_instance = super_cls(), sub_cls()
self.assertIsInstance(super_instance, ExternalClass)
self.assertIsInstance(sub_instance, ConfigurableExternalSubclass)
self.assertIsInstance(sub_instance, ExternalClass)
self.assertTrue(issubclass(cls, ExternalClass))

self.assertEqual(super_instance.kwarg1, 'one')
self.assertIsNone(super_instance.kwarg2)
instance = cls()
self.assertIsInstance(instance, ExternalClass)

self.assertIsNone(sub_instance.kwarg1)
self.assertEqual(sub_instance.kwarg2, 'two')
self.assertEqual(sub_instance.kwarg3, 'three')
self.assertEqual(instance.kwarg1, 'one')
self.assertIsNone(instance.kwarg2)

def testAbstractConfigurableSubclass(self):
config_str = """
Expand Down