Skip to content

Commit

Permalink
improve import order to allow config updates, multiple targets
Browse files Browse the repository at this point in the history
Signed-off-by: Zen <[email protected]>
  • Loading branch information
desultory committed Jan 25, 2025
1 parent e20c21e commit 9c8f6b7
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 30 deletions.
32 changes: 29 additions & 3 deletions src/ugrd/initramfs_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,31 @@ def _process_unprocessed(self, parameter_name: str) -> None:
self[parameter_name] = value


def _process_import_order(self, import_order: dict) -> None:
"""Processes the import order, setting the order requirements for import functions.
Ensures the order type is valid (before, after),
that the function is not ordered after itself.
Ensures that the same function/target is not in another order type.
"""
self.logger.debug("Processing import order:/n%s" % pretty_print(import_order))
order_types = ["before", "after"]
for order_type, order_dict in import_order.items():
if order_type not in order_types:
raise ValueError("Invalid import order type: %s" % order_type)
for function in order_dict:
targets = order_dict[function]
if not isinstance(targets, list):
targets = [targets]
if function in targets:
raise ValueError("Function cannot be ordered after itself: %s" % function)
for other_target in[self["import_order"].get(ot, {}) for ot in order_types if ot != order_type]:
if function in other_target and any(target in other_target[function] for target in targets):
raise ValueError("Function cannot be ordered in multiple types: %s" % function)
order_dict[function] = targets

self["import_order"].update(import_order)
self.logger.debug("Registered import order requirements: %s" % import_order)

def _process_import_functions(self, module, functions: list) -> list[Callable]:
"""Processes defined import functions, importing them and adding them to the returned list.
the 'function' key is required if dicts are used,
Expand All @@ -207,11 +232,12 @@ def _process_import_functions(self, module, functions: list) -> list[Callable]:
case "dict":
if "function" not in f:
raise ValueError("Function key not found in import dict: %s" % functions)
function_list.append(getattr(module, f["function"]))
func = f["function"]
function_list.append(getattr(module, func))
if "before" in f:
self["import_order"]["before"][f["function"]] = f["before"]
self["import_order"] = {"before": {func: f["before"]}}
if "after" in f:
self["import_order"]["after"][f["function"]] = f["after"]
self["import_order"] = {"after": {func: f["after"]}}
case _:
raise ValueError("Invalid type for import function: %s" % type(f))
return function_list
Expand Down
57 changes: 30 additions & 27 deletions src/ugrd/initramfs_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,43 +160,46 @@ def sort_hook_functions(self, hook: str) -> None:
if not func_names:
return self.logger.debug("No functions for hook: %s" % hook)

[before.pop(k) for k, v in before.copy().items() if k not in func_names or v not in func_names]
[after.pop(k) for k, v in after.copy().items() if k not in func_names or v not in func_names]
[before.pop(k) for k, v in before.copy().items() if k not in func_names or not any(subv in func_names for subv in v)]
[after.pop(k) for k, v in after.copy().items() if k not in func_names or not any(subv in func_names for subv in v)]

if not before and not after:
return self.logger.debug("No import order specified for hook: %s" % hook)

iterations = len(func_names) * 2 # Prevent infinite loops
while iterations:
def iter_order(order, direction):
# Iterate over all before/after functions
# If the function is not in the correct position, move it
# Use the index of the function to determine the position
# Move the function in the imports list as well
iterations -= 1
changed = False
for func_name, before_func in before.items():
func_index = func_names.index(func_name)
before_index = func_names.index(before_func)
if func_index < before_index:
self.logger.debug("Moving %s before %s" % (func_name, before_func))
func_names.insert(before_index, func_names.pop(func_index))
self["imports"][hook].insert(before_index, self["imports"][hook].pop(func_index))
changed = True
break
else:
self.logger.debug("Function %s is already before %s" % (func_name, before_func))
for func_name, after_func in after.items():

for func_name, other_funcs in order.items():
func_index = func_names.index(func_name)
after_index = func_names.index(after_func)
if func_index > after_index:
self.logger.debug("Moving %s after %s" % (func_name, after_func))
func_names.insert(after_index, func_names.pop(func_index))
self["imports"][hook].insert(after_index, self["imports"][hook].pop(func_index))
changed = True
break
else:
self.logger.debug("Function %s is already after %s" % (func_name, after_func))
if not changed:
assert func_index >= 0, "Function not found in import list: %s" % func_name
for func in other_funcs:
other_index = func_names.index(func)
assert other_index >= 0, "Function not found in import list: %s" % func
def reorder_func(direction):
self.logger.debug("Moving %s %s %s" % (func_name, direction, func))
func_names.insert(other_index, func_names.pop(func_index))
self["imports"][hook].insert(other_index, self["imports"][hook].pop(func_index))

if direction == "before":
if other_index < func_index:
reorder_func("before")
changed = True
elif direction == "after":
if other_index > func_index:
reorder_func("after")
changed = True
else:
raise ValueError("Invalid direction: %s" % direction)
return changed

iterations = len(func_names) * 2 # Prevent infinite loops
while iterations:
iterations -= 1
if not any([iter_order(before, "before"), iter_order(after, "after")]):
break
else:
self.logger.error("Import list: %s" % func_names)
Expand Down

0 comments on commit 9c8f6b7

Please sign in to comment.