Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions gin/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def drink(cocktail):
import copy
import enum
import functools
import importlib
import inspect
import logging
import os
Expand Down Expand Up @@ -156,6 +157,10 @@ def exit_scope(self):
# Keeps a set of module names that were dynamically imported via config files.
_IMPORTED_MODULES = set()

# Keeps a list of `config_parser.RegistrationBlock`s that were used to do
# in-file registrations.
_IN_FILE_REGISTRATIONS = []

# Maps `(scope, selector)` tuples to all configurable parameter values used
# during program execution (including default argument values).
_OPERATIVE_CONFIG = {}
Expand Down Expand Up @@ -746,8 +751,9 @@ def clear_config(clear_constants=False):
"""Clears the global configuration.

This clears any parameter values set by `bind_parameter` or `parse_config`, as
well as the set of dynamically imported modules. It does not remove any
configurable functions or classes from the registry of configurables.
well as the set of dynamically imported modules and in-file registrations. It
does not remove any configurable functions or classes from the registry of
configurables.

Args:
clear_constants: Whether to clear constants created by `constant`. Defaults
Expand All @@ -764,6 +770,7 @@ def clear_config(clear_constants=False):
for name, value in saved_constants.items():
constant(name, value)
_IMPORTED_MODULES.clear()
_IN_FILE_REGISTRATIONS.clear()
_OPERATIVE_CONFIG.clear()


Expand Down Expand Up @@ -1331,8 +1338,10 @@ def _make_configurable(fn_or_cls,
raise ValueError("Module '{}' is invalid.".format(module))

selector = module + '.' + name if module else name
if not _INTERACTIVE_MODE and selector in _REGISTRY:
err_str = ("A configurable matching '{}' already exists.\n\n"
if fn_or_cls in _INVERSE_REGISTRY:
pass # TODO: Fill in check for consistency.
elif not _INTERACTIVE_MODE and selector in _REGISTRY:
err_str = ("A different configurable matching '{}' already exists.\n\n"
'To allow re-registration of configurables in an interactive '
'environment, use:\n\n'
' gin.enter_interactive_mode()')
Expand Down Expand Up @@ -1677,7 +1686,16 @@ def sort_key(key_tuple):
formatted_statements = [
'import {}'.format(module) for module in sorted(_IMPORTED_MODULES)
]
if formatted_statements:
if _IMPORTED_MODULES:
formatted_statements.append('')

for registration_block in _IN_FILE_REGISTRATIONS:
module_str = registration_block.module
if registration_block.module_alias:
module_str += f' as {registration_block.module_alias}'
formatted_statements.append(f'register from {module_str}:')
for registration in registration_block.registrations:
formatted_statements.append(' ' + registration.name)
formatted_statements.append('')

macros = {}
Expand Down Expand Up @@ -1876,7 +1894,7 @@ def parse_config(bindings, skip_unknown=False):
imports.append(statement.module)
if skip_unknown:
try:
__import__(statement.module)
importlib.import_module(statement.module)
_IMPORTED_MODULES.add(statement.module)
except ImportError:
tb_len = len(traceback.extract_tb(sys.exc_info()[2]))
Expand All @@ -1892,12 +1910,20 @@ def parse_config(bindings, skip_unknown=False):
logging.info(log_str, *log_args)
else:
with utils.try_with_location(statement.location):
__import__(statement.module)
importlib.import_module(statement.module)
_IMPORTED_MODULES.add(statement.module)
elif isinstance(statement, config_parser.IncludeStatement):
with utils.try_with_location(statement.location):
nested_includes = parse_config_file(statement.filename, skip_unknown)
includes.append(nested_includes)
elif isinstance(statement, config_parser.RegistrationBlock):
with utils.try_with_location(statement.location):
module = importlib.import_module(statement.module)
for registration in statement.registrations:
with utils.try_with_location(registration.location):
fn_or_cls = getattr(module, registration.name)
register(fn_or_cls, module=statement.module_alias or statement.module)
_IN_FILE_REGISTRATIONS.append(statement)
else:
raise AssertionError('Unrecognized statement type {}.'.format(statement))
return includes, imports
Expand Down
Loading