diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py index 31627184f5375..0612507bbccda 100644 --- a/langchain/chains/llm.py +++ b/langchain/chains/llm.py @@ -13,7 +13,7 @@ class LLMChain(Chain, BaseModel): prompt: Prompt llm: LLM - return_key: str = "text" + output_key: str = "text" class Config: """Configuration for this pydantic object.""" @@ -29,7 +29,7 @@ def input_keys(self) -> List[str]: @property def output_keys(self) -> List[str]: """Will always return text key.""" - return [self.return_key] + return [self.output_key] def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]: selected_inputs = {k: inputs[k] for k in self.prompt.input_variables} @@ -39,8 +39,8 @@ def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]: if "stop" in inputs: kwargs["stop"] = inputs["stop"] response = self.llm(prompt, **kwargs) - return {self.return_key: response} + return {self.output_key: response} def predict(self, **kwargs: Any) -> str: """More user-friendly interface for interacting with LLMs.""" - return self(kwargs)[self.return_key] + return self(kwargs)[self.output_key] diff --git a/langchain/chains/python.py b/langchain/chains/python.py new file mode 100644 index 0000000000000..b3b6f1b6fc776 --- /dev/null +++ b/langchain/chains/python.py @@ -0,0 +1,37 @@ +"""Chain that runs python code.""" +import sys +from io import StringIO +from typing import Dict, List + +from pydantic import BaseModel + +from langchain.chains.base import Chain + + +class PythonChain(Chain, BaseModel): + """Chain to run python code.""" + + input_key: str = "code" + output_key: str = "output" + + @property + def input_keys(self) -> List[str]: + """Expect input in `code` key.""" + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Return output in `output` key.""" + return [self.output_key] + + def _run(self, inputs: Dict[str, str]) -> Dict[str, str]: + old_stdout = sys.stdout + sys.stdout = mystdout = StringIO() + exec(inputs[self.input_key]) + sys.stdout = old_stdout + output = mystdout.getvalue() + return {self.output_key: output} + + def run(self, code: str) -> str: + """More user-friendly interface for interfacing with python.""" + return self({self.input_key: code})[self.output_key] diff --git a/tests/unit_tests/chains/test_llm.py b/tests/unit_tests/chains/test_llm.py index 545409c265828..4c350637c7744 100644 --- a/tests/unit_tests/chains/test_llm.py +++ b/tests/unit_tests/chains/test_llm.py @@ -10,7 +10,7 @@ def fake_llm_chain() -> LLMChain: """Fake LLM chain for testing purposes.""" prompt = Prompt(input_variables=["bar"], template="This is a {bar}:") - return LLMChain(prompt=prompt, llm=FakeLLM(), return_key="text1") + return LLMChain(prompt=prompt, llm=FakeLLM(), output_key="text1") def test_missing_inputs(fake_llm_chain: LLMChain) -> None: diff --git a/tests/unit_tests/chains/test_python.py b/tests/unit_tests/chains/test_python.py new file mode 100644 index 0000000000000..1677a76a47292 --- /dev/null +++ b/tests/unit_tests/chains/test_python.py @@ -0,0 +1,15 @@ +"""Test python chain.""" + +from langchain.chains.python import PythonChain + + +def test_functionality() -> None: + """Test correct functionality.""" + chain = PythonChain(input_key="code1", output_key="output1") + code = "print(1 + 1)" + output = chain({"code1": code}) + assert output == {"code1": code, "output1": "2\n"} + + # Test with the more user-friendly interface. + simple_output = chain.run(code) + assert simple_output == "2\n"