Skip to content

Commit

Permalink
Harrison/prompt template prefix (#888)
Browse files Browse the repository at this point in the history
Co-authored-by: Gabriel Simmons <[email protected]>
  • Loading branch information
hwchase17 and g-simmons authored Feb 7, 2023
1 parent f95cedc commit e2b834e
Show file tree
Hide file tree
Showing 4 changed files with 253 additions and 3 deletions.
69 changes: 66 additions & 3 deletions docs/modules/prompts/examples/prompt_management.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@
},
{
"cell_type": "markdown",
"id": "72f32ff2",
"id": "cc991ad2",
"metadata": {},
"source": [
"## From Template\n",
Expand All @@ -163,7 +163,7 @@
{
"cell_type": "code",
"execution_count": 2,
"id": "2a81f2f8",
"id": "d0a0756c",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -174,7 +174,7 @@
{
"cell_type": "code",
"execution_count": 3,
"id": "d365b144",
"id": "59046640",
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -332,6 +332,69 @@
"print(prompt_from_string_examples.format(adjective=\"big\"))"
]
},
{
"cell_type": "markdown",
"id": "874b7575",
"metadata": {},
"source": [
"## Few Shot Prompts with Templates\n",
"We can also construct few shot prompt templates where the prefix and suffix themselves are prompt templates"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e710115f",
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts import FewShotPromptWithTemplates"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "5bf23a65",
"metadata": {},
"outputs": [],
"source": [
"prefix = PromptTemplate(input_variables=[\"content\"], template=\"This is a test about {content}.\")\n",
"suffix = PromptTemplate(input_variables=[\"new_content\"], template=\"Now you try to talk about {new_content}.\")\n",
"\n",
"prompt = FewShotPromptWithTemplates(\n",
" suffix=suffix,\n",
" prefix=prefix,\n",
" input_variables=[\"content\", \"new_content\"],\n",
" examples=examples,\n",
" example_prompt=example_prompt,\n",
" example_separator=\"\\n\",\n",
")\n",
"output = prompt.format(content=\"animals\", new_content=\"party\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "d4036351",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"This is a test about animals.\n",
"Input: happy\n",
"Output: sad\n",
"Input: tall\n",
"Output: short\n",
"Now you try to talk about party.\n"
]
}
],
"source": [
"print(output)"
]
},
{
"cell_type": "markdown",
"id": "bf038596",
Expand Down
2 changes: 2 additions & 0 deletions langchain/prompts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Prompt template classes."""
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.few_shot_with_templates import FewShotPromptWithTemplates
from langchain.prompts.loading import load_prompt
from langchain.prompts.prompt import Prompt, PromptTemplate

Expand All @@ -10,4 +11,5 @@
"PromptTemplate",
"FewShotPromptTemplate",
"Prompt",
"FewShotPromptWithTemplates",
]
145 changes: 145 additions & 0 deletions langchain/prompts/few_shot_with_templates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""Prompt template that contains few shot examples."""
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Extra, root_validator

from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING, BasePromptTemplate
from langchain.prompts.example_selector.base import BaseExampleSelector
from langchain.prompts.prompt import PromptTemplate


class FewShotPromptWithTemplates(BasePromptTemplate, BaseModel):
"""Prompt template that contains few shot examples."""

examples: Optional[List[dict]] = None
"""Examples to format into the prompt.
Either this or example_selector should be provided."""

example_selector: Optional[BaseExampleSelector] = None
"""ExampleSelector to choose the examples to format into the prompt.
Either this or examples should be provided."""

example_prompt: PromptTemplate
"""PromptTemplate used to format an individual example."""

suffix: BasePromptTemplate
"""A PromptTemplate to put after the examples."""

input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""

example_separator: str = "\n\n"
"""String separator used to join the prefix, the examples, and suffix."""

prefix: Optional[BasePromptTemplate] = None
"""A PromptTemplate to put before the examples."""

template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""

validate_template: bool = True
"""Whether or not to try validating the template."""

@root_validator(pre=True)
def check_examples_and_selector(cls, values: Dict) -> Dict:
"""Check that one and only one of examples/example_selector are provided."""
examples = values.get("examples", None)
example_selector = values.get("example_selector", None)
if examples and example_selector:
raise ValueError(
"Only one of 'examples' and 'example_selector' should be provided"
)

if examples is None and example_selector is None:
raise ValueError(
"One of 'examples' and 'example_selector' should be provided"
)

return values

@root_validator()
def template_is_valid(cls, values: Dict) -> Dict:
"""Check that prefix, suffix and input variables are consistent."""
input_variables = values["input_variables"]
expected_input_variables = set(values["suffix"].input_variables)
if values["prefix"] is not None:
expected_input_variables |= set(values["prefix"].input_variables)
missing_vars = expected_input_variables.difference(input_variables)
if missing_vars:
raise ValueError(
f"Got input_variables={input_variables}, but based on prefix/suffix "
f"expected {expected_input_variables}"
)
return values

class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid
arbitrary_types_allowed = True

def _get_examples(self, **kwargs: Any) -> List[dict]:
if self.examples is not None:
return self.examples
elif self.example_selector is not None:
return self.example_selector.select_examples(kwargs)
else:
raise ValueError

def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
Example:
.. code-block:: python
prompt.format(variable1="foo")
"""
# Get the examples to use.
examples = self._get_examples(**kwargs)
# Format the examples.
example_strings = [
self.example_prompt.format(**example) for example in examples
]
# Create the overall prefix.
if self.prefix is None:
prefix = ""
else:
prefix_kwargs = {
k: v for k, v in kwargs.items() if k in self.prefix.input_variables
}
for k in prefix_kwargs.keys():
kwargs.pop(k)
prefix = self.prefix.format(**prefix_kwargs)

# Create the overall suffix
suffix_kwargs = {
k: v for k, v in kwargs.items() if k in self.suffix.input_variables
}
for k in suffix_kwargs.keys():
kwargs.pop(k)
suffix = self.suffix.format(
**suffix_kwargs,
)

pieces = [prefix, *example_strings, suffix]
template = self.example_separator.join([piece for piece in pieces if piece])
# Format the template with the input variables.
return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)

@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""
return "few_shot_with_templates"

def dict(self, **kwargs: Any) -> Dict:
"""Return a dictionary of the prompt."""
if self.example_selector:
raise ValueError("Saving an example selector is not currently supported")
return super().dict(**kwargs)
40 changes: 40 additions & 0 deletions tests/unit_tests/prompts/test_few_shot_with_templates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Test few shot prompt template."""

from langchain.prompts.few_shot_with_templates import FewShotPromptWithTemplates
from langchain.prompts.prompt import PromptTemplate

EXAMPLE_PROMPT = PromptTemplate(
input_variables=["question", "answer"], template="{question}: {answer}"
)


def test_prompttemplate_prefix_suffix() -> None:
"""Test that few shot works when prefix and suffix are PromptTemplates."""
prefix = PromptTemplate(
input_variables=["content"], template="This is a test about {content}."
)
suffix = PromptTemplate(
input_variables=["new_content"],
template="Now you try to talk about {new_content}.",
)

examples = [
{"question": "foo", "answer": "bar"},
{"question": "baz", "answer": "foo"},
]
prompt = FewShotPromptWithTemplates(
suffix=suffix,
prefix=prefix,
input_variables=["content", "new_content"],
examples=examples,
example_prompt=EXAMPLE_PROMPT,
example_separator="\n",
)
output = prompt.format(content="animals", new_content="party")
expected_output = (
"This is a test about animals.\n"
"foo: bar\n"
"baz: foo\n"
"Now you try to talk about party."
)
assert output == expected_output

0 comments on commit e2b834e

Please sign in to comment.