diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index 3fc82ec9febb9..7e7fa7b45bdf0 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -389,7 +389,7 @@ def add_ten(x: Dict[str, int]) -> Dict[str, int]: # returns {'input': 5, 'add_step': {'added': 15}} """ - mapper: RunnableParallel[dict[str, Any]] + mapper: RunnableParallel def __init__(self, mapper: RunnableParallel[dict[str, Any]], **kwargs: Any) -> None: super().__init__(mapper=mapper, **kwargs) # type: ignore[call-arg] diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 211fab46833db..f9dff49c0b99d 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -63,6 +63,7 @@ ConfigurableFieldSingleOption, RouterRunnable, Runnable, + RunnableAssign, RunnableBinding, RunnableBranch, RunnableConfig, @@ -5413,3 +5414,14 @@ def test_schema_for_prompt_and_chat_model() -> None: "title": "PromptInput", "type": "object", } + + +def test_runnable_assign() -> None: + def add_ten(x: dict[str, int]) -> dict[str, int]: + return {"added": x["input"] + 10} + + mapper = RunnableParallel({"add_step": RunnableLambda(add_ten)}) + runnable_assign = RunnableAssign(mapper) + + result = runnable_assign.invoke({"input": 5}) + assert result == {"input": 5, "add_step": {"added": 15}}