diff --git a/gin/config.py b/gin/config.py index 301b722..a22c85d 100644 --- a/gin/config.py +++ b/gin/config.py @@ -191,7 +191,7 @@ def exit_scope(self): def _find_class_construction_fn(cls): """Find the first __init__ or __new__ method in the given class's MRO.""" - for base in type.mro(cls): # pytype: disable=wrong-arg-types + for base in cls.mro(): # pytype: disable=attribute-error if '__init__' in base.__dict__: return base.__init__ if '__new__' in base.__dict__: @@ -216,54 +216,84 @@ def _ensure_wrappability(fn): return fn -def _decorate_fn_or_cls(decorator, fn_or_cls, subclass=False): +def _decorate_fn_or_cls(decorator, fn_or_cls, avoid_class_mutation=False): """Decorate a function or class with the given decorator. When `fn_or_cls` is a function, applies `decorator` to the function and returns the (decorated) result. - When `fn_or_cls` is a class and the `subclass` parameter is `False`, this will - replace `fn_or_cls.__init__` with the result of applying `decorator` to it. - - When `fn_or_cls` is a class and `subclass` is `True`, this will subclass the - class, but with `__init__` defined to be the result of applying `decorator` to - `fn_or_cls.__init__`. The decorated class has metadata (docstring, name, and - module information) copied over from `fn_or_cls`. The goal is to provide a - decorated class the behaves as much like the original as possible, without - modifying it (for example, inspection operations using `isinstance` or - `issubclass` should behave the same way as on the original class). + When `fn_or_cls` is a class and the `avoid_class_mutation` parameter is + `False`, this will replace either `fn_or_cls.__init__` or `fn_or_cls.__new__` + (whichever is first implement in the class's MRO, with a preference for + `__init__`) with the result of applying `decorator` to it. + + When `fn_or_cls` is a class and `avoid_class_mutation` is `True`, this will + dynamically construct a subclass of `fn_or_cls` using a dynamically + constructed metaclass (which itself is a subclass of `fn_or_cls`'s metaclass). + The metaclass's `__call__` method is wrapped using `decorator` to + intercept/inject paramaters for class construction. The resulting subclass has + metadata (docstring, name, and module information) copied over from + `fn_or_cls`, and should behave like the original as much possible, without + modifying it (for example, inspection operations using `issubclass` should + behave the same way as on the original class). When constructed, an instance + of the original (undecorated) class is returned. Args: decorator: The decorator to use. fn_or_cls: The function or class to decorate. - subclass: Whether to decorate classes by subclassing. This argument is - ignored if `fn_or_cls` is not a class. + avoid_class_mutation: Whether to avoid class mutation using dynamic + subclassing. This argument is ignored if `fn_or_cls` is not a class. Returns: The decorated function or class. """ if not inspect.isclass(fn_or_cls): # pytype: disable=wrong-arg-types return decorator(_ensure_wrappability(fn_or_cls)) - - construction_fn = _find_class_construction_fn(fn_or_cls) - - if subclass: - - class DecoratedClass(fn_or_cls): - __doc__ = fn_or_cls.__doc__ - __module__ = fn_or_cls.__module__ - - DecoratedClass.__name__ = fn_or_cls.__name__ - DecoratedClass.__qualname__ = fn_or_cls.__qualname__ - cls = DecoratedClass + cls = fn_or_cls + if avoid_class_mutation: + # This approach enables @gin.register and gin.external_configurable(), and + # is compatible with pickling instances. However, we can't use it for + # @gin.configurable because the decorated class returned below interacts + # poorly with super() calls when subclassed. In general, there isn't a + # strong use case for dynamically subclassing the Gin-wrapped classes + # resulting from @gin.register and gin.external_configurable() classes, so + # this tradeoff should be ok. + cls_meta = type(cls) # First, determine the metaclass of the given class. + # Now, we wrap the __call__ method on the metaclass. It is important to note + # that we ignore the first parameter (normally would be an instance of the + # metaclass we're creating, e.g. `decorated_cls` from below) and just pass + # `cls` directly to cls_meta's `__call__`. This ensures that we construct + # an actual instance of `cls`, and not our dynamically created subclass from + # below, which enables pickling instances of the class. + meta_call = lambda _, *a, **kw: cls_meta.__call__(cls, *a, **kw) + # We decorate our wrapped metaclass __call__ with Gin's wrapper. + decorated_call = decorator(_ensure_wrappability(meta_call)) + # And now construct a new metaclass, subclassing the one from `cls`, + # supplying our decorated `__call__`. Most often this is just subclassing + # Python's `type`, but when `cls` has a custom metaclass set, this ensures + # that it will continue to work properly. + decorating_meta = type(cls_meta.__name__, (cls_meta,), { + '__call__': decorated_call, + }) + # Now we construct our class. This is a subclass of `cls`, but with no + # overrides, since injecting/intercepting parameters is all handled in the + # metaclass's `__call__` method. + decorated_class = decorating_meta(cls.__name__, (cls,), {}) + decorated_class.__name__ = cls.__name__ + decorated_class.__doc__ = cls.__doc__ + decorated_class.__qualname__ = cls.__qualname__ + decorated_class.__module__ = cls.__module__ else: - cls = fn_or_cls - - decorated_fn = decorator(_ensure_wrappability(construction_fn)) - if construction_fn.__name__ == '__new__': - decorated_fn = staticmethod(decorated_fn) - setattr(cls, construction_fn.__name__, decorated_fn) - return cls + # Here, we just decorate `__init__` or `__new__` directly, and mutate the + # original class definition to use the decorated version. This is simpler + # and permits reliable subclassing of @gin.configurable decorated classes. + decorated_class = cls + construction_fn = _find_class_construction_fn(decorated_class) + decorated_fn = decorator(_ensure_wrappability(construction_fn)) + if construction_fn.__name__ == '__new__': + decorated_fn = staticmethod(decorated_fn) + setattr(decorated_class, construction_fn.__name__, decorated_fn) + return decorated_class class Configurable( @@ -918,7 +948,8 @@ def _make_gin_wrapper(fn, fn_or_cls, name, selector, allowlist, denylist): fn: The function that will be wrapped. fn_or_cls: The original function or class being made configurable. This will differ from `fn` when making a class configurable, in which case `fn` will - be the constructor/new function, while `fn_or_cls` will be the class. + be the constructor/new function (or when proxying a class, the type's + `__call__` method), while `fn_or_cls` will be the class. name: The name given to the configurable. selector: The full selector of the configurable (name including any module components). @@ -931,10 +962,13 @@ def _make_gin_wrapper(fn, fn_or_cls, name, selector, allowlist, denylist): # At this point we have access to the final function to be wrapped, so we # can cache a few things here. fn_descriptor = "'{}' ('{}')".format(name, fn_or_cls) + signature_fn = fn_or_cls + if inspect.isclass(fn_or_cls): + signature_fn = _find_class_construction_fn(fn_or_cls) signature_required_kwargs = _get_validated_required_kwargs( - fn, fn_descriptor, allowlist, denylist) + signature_fn, fn_descriptor, allowlist, denylist) initial_configurable_defaults = _get_default_configurable_parameter_values( - fn, allowlist, denylist) + signature_fn, allowlist, denylist) @functools.wraps(fn) def gin_wrapper(*args, **kwargs): @@ -947,7 +981,7 @@ def gin_wrapper(*args, **kwargs): gin_bound_args = list(new_kwargs.keys()) scope_str = partial_scope_str - arg_names = _get_supplied_positional_parameter_names(fn, args) + arg_names = _get_supplied_positional_parameter_names(signature_fn, args) for arg in args[len(arg_names):]: if arg is REQUIRED: @@ -1032,7 +1066,7 @@ def gin_wrapper(*args, **kwargs): if missing_required_params: missing_required_params = ( - _order_by_signature(fn, missing_required_params)) + _order_by_signature(signature_fn, missing_required_params)) err_str = 'Required bindings for `{}` not provided in config: {}' minimal_selector = _REGISTRY.minimal_selector(selector) err_str = err_str.format(minimal_selector, missing_required_params) @@ -1047,7 +1081,7 @@ def gin_wrapper(*args, **kwargs): except Exception as e: # pylint: disable=broad-except err_str = '' if isinstance(e, TypeError): - all_arg_names = _get_all_positional_parameter_names(fn) + all_arg_names = _get_all_positional_parameter_names(signature_fn) if len(new_args) < len(all_arg_names): unbound_positional_args = list( set(all_arg_names[len(new_args):]) - set(new_kwargs)) @@ -1076,7 +1110,7 @@ def _make_configurable(fn_or_cls, module=None, allowlist=None, denylist=None, - subclass=False): + avoid_class_mutation=False): """Wraps `fn_or_cls` to make it configurable. Infers the configurable name from `fn_or_cls.__name__` if necessary, and @@ -1094,9 +1128,10 @@ def _make_configurable(fn_or_cls, is specified as part of `name`). allowlist: An allowlisted set of parameter names to supply values for. denylist: A denylisted set of parameter names not to supply values for. - subclass: If `fn_or_cls` is a class and `subclass` is `True`, decorate by - subclassing `fn_or_cls` and overriding its `__init__` method. If `False`, - replace the existing `__init__` with a decorated version. + avoid_class_mutation: If `fn_or_cls` is a class and `avoid_class_mutation` + is `True`, decorate by subclassing `fn_or_cls`'s metaclass and overriding + its `__call__` method. If `False`, replace the existing `__init__` or + `__new__` with a decorated version. Returns: A wrapped version of `fn_or_cls` that will take parameter values from the @@ -1148,7 +1183,7 @@ def decorator(fn): denylist) decorated_fn_or_cls = _decorate_fn_or_cls( - decorator, fn_or_cls, subclass=subclass) + decorator, fn_or_cls, avoid_class_mutation=avoid_class_mutation) _REGISTRY[selector] = Configurable( decorated_fn_or_cls, @@ -1310,7 +1345,7 @@ def external_configurable(fn_or_cls, module=module, allowlist=allowlist, denylist=denylist, - subclass=True) + avoid_class_mutation=True) def register(name_or_fn=None, @@ -1402,14 +1437,14 @@ def __init__(self, param1, param2='a default value'): name = name_or_fn def perform_decoration(fn_or_cls): - # Register it as configurable but return the orinal fn_or_cls. + # Register it as configurable but return the original fn_or_cls. _make_configurable( fn_or_cls, name=name, module=module, allowlist=allowlist, denylist=denylist, - subclass=True) + avoid_class_mutation=True) return fn_or_cls if decoration_target: @@ -1423,7 +1458,7 @@ def _config_str(configuration_object, """Print the configuration specified in configuration object. Args: - configuration_object: Either OPERATIVE_CONFIG_ (operative config) or _CONFIG + configuration_object: Either _OPERATIVE_CONFIG (operative config) or _CONFIG (all config, bound and unbound). max_line_length: A (soft) constraint on the maximum length of a line in the formatted string. Large nested structures will be split across lines, but diff --git a/tests/config_test.py b/tests/config_test.py index b33d1bc..c8bd720 100644 --- a/tests/config_test.py +++ b/tests/config_test.py @@ -310,20 +310,6 @@ def __init__(self, kwarg1=None, kwarg2=None): config.external_configurable(ExternalClass, 'module.ExternalConfigurable2') -@config.configurable -class ConfigurableExternalSubclass(configurable_external_class): - """Subclassing an external configurable object. - - This is a configurable subclass (of the configurable subclass implicitly - created by external_configurable) of the ExternalClass class. - """ - - def __init__(self, kwarg1=None, kwarg2=None, kwarg3=None): - super(ConfigurableExternalSubclass, self).__init__( - kwarg1=kwarg1, kwarg2=kwarg2) - self.kwarg3 = kwarg3 - - class AbstractConfigurable(metaclass=abc.ABCMeta): def __init__(self, kwarg1=None): @@ -696,7 +682,9 @@ def testConfigurableSubclass(self): # subclasses of the original class are not subclasses of the reference. self.assertFalse(issubclass(sub_cls_ref, super_cls_ref)) self.assertNotIsInstance(sub_instance, super_cls_ref) - self.assertNotIsInstance(sub_instance, type(super_instance)) + # But due to the fact that Gin's dynamic metaclass creates instances of the + # actual class (not the Gin subclass), instance checks work. + self.assertIsInstance(sub_instance, type(super_instance)) self.assertEqual(super_instance.kwarg1, 'one') self.assertIsNone(super_instance.kwarg2) @@ -770,33 +758,22 @@ def testImplicitlyScopedConfigurableClass(self): self.assertEqual(scope2_instance.kwarg1, 'scope2arg1') self.assertEqual(scope2_instance.kwarg2, 'scope2arg2') - def testImplicitlyScopedExternalConfigurableAndSubclass(self): + def testImplicitlyScopedExternalConfigurable(self): config_str = """ configurable2.non_kwarg = @scope1/ExternalConfigurable - configurable2.kwarg1 = @scope2/ConfigurableExternalSubclass scope1/ExternalConfigurable.kwarg1 = 'one' - scope2/ConfigurableExternalSubclass.kwarg2 = 'two' - scope2/ConfigurableExternalSubclass.kwarg3 = 'three' """ config.parse_config(config_str) # pylint: disable=no-value-for-parameter - super_cls, sub_cls = configurable2() + cls, _ = configurable2() # pylint: enable=no-value-for-parameter - self.assertTrue(issubclass(super_cls, ExternalClass)) - self.assertTrue(issubclass(sub_cls, ExternalClass)) - self.assertTrue(issubclass(sub_cls, ConfigurableExternalSubclass)) - - super_instance, sub_instance = super_cls(), sub_cls() - self.assertIsInstance(super_instance, ExternalClass) - self.assertIsInstance(sub_instance, ConfigurableExternalSubclass) - self.assertIsInstance(sub_instance, ExternalClass) + self.assertTrue(issubclass(cls, ExternalClass)) - self.assertEqual(super_instance.kwarg1, 'one') - self.assertIsNone(super_instance.kwarg2) + instance = cls() + self.assertIsInstance(instance, ExternalClass) - self.assertIsNone(sub_instance.kwarg1) - self.assertEqual(sub_instance.kwarg2, 'two') - self.assertEqual(sub_instance.kwarg3, 'three') + self.assertEqual(instance.kwarg1, 'one') + self.assertIsNone(instance.kwarg2) def testAbstractConfigurableSubclass(self): config_str = """