Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: Allow nested prompt templates #28024

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 29 additions & 18 deletions libs/core/langchain_core/prompts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,7 @@
from collections.abc import Mapping
from functools import cached_property
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Optional,
TypeVar,
Union,
)
from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, Union

import yaml
from pydantic import BaseModel, ConfigDict, Field, model_validator
Expand All @@ -36,7 +28,6 @@
if TYPE_CHECKING:
from langchain_core.documents import Document


FormatOutputType = TypeVar("FormatOutputType")


Expand Down Expand Up @@ -260,27 +251,47 @@ async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
"""
return self.format_prompt(**kwargs)

def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
def partial(
self, **kwargs: Union[str, Callable[[], str], BasePromptTemplate]
) -> BasePromptTemplate:
"""Return a partial of the prompt template.

Args:
kwargs: Union[str, Callable[[], str], partial variables to set.
kwargs: Union[str, Callable[[], str], BasePromptTemplate],
partial variables to set.

Returns:
BasePromptTemplate: A partial of the prompt template.
"""
prompt_dict = self.__dict__.copy()
prompt_dict["input_variables"] = list(
set(self.input_variables).difference(kwargs)
input_vars = set(self.input_variables).difference(kwargs)
partial_vars = {}
for key, partial_var in kwargs.items():
if isinstance(partial_var, BasePromptTemplate):
# Prepare partial arguments, excluding the current key
new_kwargs = kwargs.copy()
new_kwargs.pop(key)
partial_var = partial_var.partial(**new_kwargs)
partial_vars[key] = partial_var
prompt_dict.update(
{
"input_variables": list(input_vars),
"partial_variables": {**kwargs, **partial_vars},
}
)
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are we losing self.partial_variables, here?

return type(self)(**prompt_dict)

def _merge_partial_and_user_variables(self, **kwargs: Any) -> dict[str, Any]:
# Get partial params:
partial_kwargs = {
k: v if not callable(v) else v() for k, v in self.partial_variables.items()
}
partial_kwargs = {}
for k, v in self.partial_variables.items():
if isinstance(v, BasePromptTemplate):
# Propagate partial variables and kwargs to nested prompt templates
partial_kwargs[k] = v.format(**kwargs)
elif callable(v):
partial_kwargs[k] = v()
else:
partial_kwargs[k] = v
return {**partial_kwargs, **kwargs}

@abstractmethod
Expand Down
13 changes: 13 additions & 0 deletions libs/core/langchain_core/prompts/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from pydantic import BaseModel, model_validator

from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,
PromptTemplateFormat,
Expand Down Expand Up @@ -104,12 +105,24 @@ def pre_init_validation(cls, values: dict) -> Any:
)

if values["template_format"]:
# Collect nested partial variables from
# BasePromptTemplate instances in partial_variables
nested_partial_vars = {
key
for partial_var in values["partial_variables"].values()
if isinstance(partial_var, BasePromptTemplate)
for key in partial_var.partial_variables
}

# Filter template variables based on
# partial_variables and nested_partial_vars
values["input_variables"] = [
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only point I found a bit counterintuitive. Imo if the user passes input_variables to the PromptTemplate we should not overwrite them with values from the template. This complicates the logic as I had to make sure that the nested partial variables are correctly excluded, to not break backward compatibility.

var
for var in get_template_variables(
values["template"], values["template_format"]
)
if var not in values["partial_variables"]
and var not in nested_partial_vars
]

return values
Expand Down
56 changes: 54 additions & 2 deletions libs/core/tests/unit_tests/prompts/test_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,58 @@ def test_partial() -> None:
assert result == "This is a foo test."


def test_nested_prompt_template_as_partial() -> None:
"""Test prompt with PromptTemplate as partial variable."""
template_nested = "{bar}"
prompt_nested = PromptTemplate(input_variables=["bar"], template=template_nested)

template = "This is a {foo} test."
prompt = PromptTemplate(input_variables=["foo"], template=template)
assert prompt.template == template
assert prompt.input_variables == ["foo"]

new_prompt = prompt.partial(foo=prompt_nested)
assert new_prompt.input_variables == []
assert new_prompt.partial_variables["foo"].input_variables == ["bar"]
assert new_prompt.partial_variables["foo"].partial_variables == {}
result = new_prompt.format(bar="bar")
assert result == "This is a bar test."

new_prompt = prompt.partial(foo=prompt_nested, bar="bar")
assert new_prompt.input_variables == []
assert new_prompt.partial_variables["foo"].input_variables == []
assert new_prompt.partial_variables["foo"].partial_variables == {"bar": "bar"}
result = new_prompt.format()
assert result == "This is a bar test."


def test_nested_prompt_template_with_shared_variable() -> None:
"""Test prompt with PromptTemplate as partial variable, sharing another variable."""
template_nested = "{bar}"
prompt_nested = PromptTemplate(
input_variables=["bar", "foo"], template=template_nested
)

template = "This is a {foo} {bar} test."
prompt = PromptTemplate(input_variables=["foo", "bar"], template=template)
assert prompt.template == template
assert prompt.input_variables == ["bar", "foo"]

new_prompt = prompt.partial(foo=prompt_nested)
assert new_prompt.input_variables == ["bar"]
assert new_prompt.partial_variables["foo"].input_variables == ["bar"]
assert new_prompt.partial_variables["foo"].partial_variables == {}
result = new_prompt.format(bar="bar")
assert result == "This is a bar bar test."

new_prompt = prompt.partial(foo=prompt_nested, bar="bar")
assert new_prompt.input_variables == []
assert new_prompt.partial_variables["foo"].input_variables == []
assert new_prompt.partial_variables["foo"].partial_variables == {"bar": "bar"}
result = new_prompt.format()
assert result == "This is a bar bar test."


@pytest.mark.requires("jinja2")
def test_prompt_from_jinja2_template() -> None:
"""Test prompts can be constructed from a jinja2 template."""
Expand Down Expand Up @@ -508,7 +560,7 @@ def test_prompt_jinja2_missing_input_variables() -> None:

@pytest.mark.requires("jinja2")
def test_prompt_jinja2_extra_input_variables() -> None:
"""Test error is raised when there are too many input variables."""
"""Test warning is raised when there are too many input variables."""
template = "This is a {{ foo }} test."
input_variables = ["foo", "bar"]
with pytest.warns(UserWarning):
Expand All @@ -525,7 +577,7 @@ def test_prompt_jinja2_extra_input_variables() -> None:

@pytest.mark.requires("jinja2")
def test_prompt_jinja2_wrong_input_variables() -> None:
"""Test error is raised when name of input variable is wrong."""
"""Test warning is raised when name of input variable is wrong."""
template = "This is a {{ foo }} test."
input_variables = ["bar"]
with pytest.warns(UserWarning):
Expand Down
Loading