From 0b9fa63b3dc613a604967e1594e7376ac2f3ff04 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sun, 16 Oct 2022 20:40:25 -0700 Subject: [PATCH] enforce mypy and fix errors --- langchain/formatting.py | 12 ++++++++++-- pyproject.toml | 1 + tests/unit_tests/test_formatting.py | 6 +++--- tests/unit_tests/test_prompt.py | 10 +++++----- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/langchain/formatting.py b/langchain/formatting.py index 88f83d7893cc2..61c7c11641b61 100644 --- a/langchain/formatting.py +++ b/langchain/formatting.py @@ -1,17 +1,25 @@ """Utilities for formatting strings.""" from string import Formatter +from typing import Any, Mapping, Sequence, Union class StrictFormatter(Formatter): """A subclass of formatter that checks for extra keys.""" - def check_unused_args(self, used_args, args, kwargs): + def check_unused_args( + self, + used_args: Sequence[Union[int, str]], + args: Sequence, + kwargs: Mapping[str, Any], + ) -> None: """Check to see if extra parameters are passed.""" extra = set(kwargs).difference(used_args) if extra: raise KeyError(extra) - def vformat(self, format_string, args, kwargs): + def vformat( + self, format_string: str, args: Sequence, kwargs: Mapping[str, Any] + ) -> str: """Check that no arguments are provided.""" if len(args) > 0: raise ValueError( diff --git a/pyproject.toml b/pyproject.toml index 84d4a1f503aa8..8eedb8d89e33c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,4 +3,5 @@ profile = "black" [tool.mypy] ignore_missing_imports = "True" +disallow_untyped_defs = "True" exclude = ["notebooks"] diff --git a/tests/unit_tests/test_formatting.py b/tests/unit_tests/test_formatting.py index 5db90e37ab920..168e580b7b9f2 100644 --- a/tests/unit_tests/test_formatting.py +++ b/tests/unit_tests/test_formatting.py @@ -4,7 +4,7 @@ from langchain.formatting import formatter -def test_valid_formatting(): +def test_valid_formatting() -> None: """Test formatting works as expected.""" template = "This is a {foo} test." output = formatter.format(template, foo="good") @@ -12,14 +12,14 @@ def test_valid_formatting(): assert output == expected_output -def test_does_not_allow_args(): +def test_does_not_allow_args() -> None: """Test formatting raises error when args are provided.""" template = "This is a {} test." with pytest.raises(ValueError): formatter.format(template, "good") -def test_does_not_allow_extra_kwargs(): +def test_does_not_allow_extra_kwargs() -> None: """Test formatting does not allow extra key word arguments.""" template = "This is a {foo} test." with pytest.raises(KeyError): diff --git a/tests/unit_tests/test_prompt.py b/tests/unit_tests/test_prompt.py index 7e6ff31fb9d8d..ac06f7e4b0602 100644 --- a/tests/unit_tests/test_prompt.py +++ b/tests/unit_tests/test_prompt.py @@ -4,7 +4,7 @@ from langchain.prompt import Prompt -def test_prompt_valid(): +def test_prompt_valid() -> None: """Test prompts can be constructed.""" template = "This is a {foo} test." input_variables = ["foo"] @@ -13,15 +13,15 @@ def test_prompt_valid(): assert prompt.input_variables == input_variables -def test_prompt_missing_input_variables(): +def test_prompt_missing_input_variables() -> None: """Test error is raised when input variables are not provided.""" template = "This is a {foo} test." - input_variables = [] + input_variables: list = [] with pytest.raises(ValueError): Prompt(input_variables=input_variables, template=template) -def test_prompt_extra_input_variables(): +def test_prompt_extra_input_variables() -> None: """Test error is raised when there are too many input variables.""" template = "This is a {foo} test." input_variables = ["foo", "bar"] @@ -29,7 +29,7 @@ def test_prompt_extra_input_variables(): Prompt(input_variables=input_variables, template=template) -def test_prompt_wrong_input_variables(): +def test_prompt_wrong_input_variables() -> None: """Test error is raised when name of input variable is wrong.""" template = "This is a {foo} test." input_variables = ["bar"]