diff --git a/src/autora/state.py b/src/autora/state.py index 4de7e8a7..fdb8f4be 100644 --- a/src/autora/state.py +++ b/src/autora/state.py @@ -722,7 +722,7 @@ def _append(a: List[T], b: T) -> List[T]: return a + [b] -def inputs_from_state(f): +def inputs_from_state(f, input_mapping: Dict = {}): """Decorator to make target `f` into a function on a `State` and `**kwargs`. This wrapper makes it easier to pass arguments to a function from a State. @@ -732,6 +732,7 @@ def inputs_from_state(f): Args: f: a function with arguments that could be fields on a `State` and that returns a `Delta`. + input_mapping: a dict that maps the input arguments of the function to the state fields Returns: a version of `f` which takes and returns `State` objects. @@ -758,6 +759,22 @@ def inputs_from_state(f): >>> experimentalist(U(conditions=[101,102,103,104])) [111, 112, 113, 114] + If our function uses a different keyword argument than the state field, we can use + the input mapping: + >>> def experimentalist_(X): + ... new_conditions = [x + 10 for x in X] + ... return new_conditions + >>> experimentalist_on_state = inputs_from_state(experimentalist_, {'X': 'conditions'}) + >>> experimentalist_on_state(U(conditions=[1,2,3,4])) + [11, 12, 13, 14] + + Both also work with the `State` as UserDict. Here, we use the StandardState + >>> experimentalist(StandardState(conditions=[1, 2, 3, 4])) + [11, 12, 13, 14] + + >>> experimentalist_on_state(StandardState(conditions=[1, 2, 3, 4])) + [11, 12, 13, 14] + A dictionary can be returned and used: >>> @inputs_from_state ... def returns_a_dictionary(conditions): @@ -831,6 +848,14 @@ def inputs_from_state(f): >>> experimentalist(U(conditions=[1,2,3,4]), offset=2) [3, 4, 5, 6] + The same is true, if we don't provide a mapping for arguments: + >>> def experimentalist_(X, offset): + ... new_conditions = [x + offset for x in X] + ... return new_conditions + >>> experimentalist_on_state = inputs_from_state(experimentalist_, {'X': 'conditions'}) + >>> experimentalist_on_state(StandardState(conditions=[1,2,3,4]), offset=2) + [3, 4, 5, 6] + The state itself is passed through if the inner function requests the `state`: >>> @inputs_from_state ... def function_which_needs_whole_state(state, conditions): @@ -843,7 +868,16 @@ def inputs_from_state(f): """ # Get the set of parameter names from function f's signature + + reversed_mapping = {v: k for k, v in input_mapping.items()} + parameters_ = set(inspect.signature(f).parameters.keys()) + missing_func_params = set(input_mapping.keys()).difference(parameters_) + if missing_func_params: + raise ValueError( + f"The following keys in input_state_mapping are not parameters of the function: " + f"{missing_func_params}" + ) @wraps(f) def _f(state_: S, /, **kwargs) -> S: @@ -853,9 +887,21 @@ def _f(state_: S, /, **kwargs) -> S: if is_dataclass(state_): from_state = parameters_.intersection({i.name for i in fields(state_)}) arguments_from_state = {k: getattr(state_, k) for k in from_state} + from_state_input_mapping = { + reversed_mapping.get(f.name, f.name): getattr(state_, f.name) + for f in fields(state_) + if reversed_mapping.get(f.name, f.name) in parameters_ + } + arguments_from_state.update(from_state_input_mapping) elif isinstance(state_, UserDict): from_state = parameters_.intersection(set(state_.keys())) arguments_from_state = {k: state_[k] for k in from_state} + from_state_input_mapping = { + reversed_mapping.get(key, key): state_[key] + for key in state_.keys() + if reversed_mapping.get(key, key) in parameters_ + } + arguments_from_state.update(from_state_input_mapping) if "state" in parameters_: arguments_from_state["state"] = state_ arguments = dict(arguments_from_state, **kwargs) @@ -1134,7 +1180,9 @@ def _f(state_: S, **kwargs) -> S: def on_state( - function: Optional[Callable] = None, output: Optional[Sequence[str]] = None + function: Optional[Callable] = None, + input_mapping: Dict = {}, + output: Optional[Sequence[str]] = None, ): """Decorator (factory) to make target `function` into a function on a `State` and `**kwargs`. @@ -1143,6 +1191,7 @@ def on_state( Args: function: the function to be wrapped output: list specifying State field names for the return values of `function` + input_mapping: a dict that maps the keywords of the functions to the state fields Returns: @@ -1193,13 +1242,26 @@ def on_state( >>> add_six(W(conditions=[1, 2, 3, 4])) W(conditions=[7, 8, 9, 10]) + You can also declare an input-to-output mapping if the keyword arguments of the functions + don't match the state fields: + >>> @on_state(input_mapping={'X': 'conditions'}, output=["conditions"]) + ... def add_six(X): + ... return [x + 6 for x in X] + + >>> add_six(W(conditions=[1, 2, 3, 4])) + W(conditions=[7, 8, 9, 10]) + + This also works on the StandardState or other States that are defined as UserDicts: + >>> add_six(StandardState(conditions=[1, 2, 3,4])).conditions + [7, 8, 9, 10] + """ def decorator(f): f_ = f if output is not None: f_ = outputs_to_delta(*output)(f_) - f_ = inputs_from_state(f_) + f_ = inputs_from_state(f_, input_mapping) f_ = delta_to_state(f_) return f_