Skip to content

Commit

Permalink
validate template (langchain-ai#865)
Browse files Browse the repository at this point in the history
  • Loading branch information
hwchase17 authored and zachschillaci27 committed Mar 8, 2023
1 parent 84fbba1 commit a9fba55
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
14 changes: 9 additions & 5 deletions langchain/prompts/few_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ class FewShotPromptTemplate(BasePromptTemplate, BaseModel):
template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""

validate_template: bool = True
"""Whether or not to try validating the template."""

@root_validator(pre=True)
def check_examples_and_selector(cls, values: Dict) -> Dict:
"""Check that one and only one of examples/example_selector are provided."""
Expand All @@ -61,11 +64,12 @@ def check_examples_and_selector(cls, values: Dict) -> Dict:
@root_validator()
def template_is_valid(cls, values: Dict) -> Dict:
"""Check that prefix, suffix and input variables are consistent."""
check_valid_template(
values["prefix"] + values["suffix"],
values["template_format"],
values["input_variables"],
)
if values["validate_template"]:
check_valid_template(
values["prefix"] + values["suffix"],
values["template_format"],
values["input_variables"],
)
return values

class Config:
Expand Down
10 changes: 7 additions & 3 deletions langchain/prompts/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ class PromptTemplate(BasePromptTemplate, BaseModel):
template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""

validate_template: bool = True
"""Whether or not to try validating the template."""

@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""
Expand Down Expand Up @@ -61,9 +64,10 @@ def format(self, **kwargs: Any) -> str:
@root_validator()
def template_is_valid(cls, values: Dict) -> Dict:
"""Check that template and input variables are consistent."""
check_valid_template(
values["template"], values["template_format"], values["input_variables"]
)
if values["validate_template"]:
check_valid_template(
values["template"], values["template_format"], values["input_variables"]
)
return values

@classmethod
Expand Down

0 comments on commit a9fba55

Please sign in to comment.