-
Notifications
You must be signed in to change notification settings - Fork 16.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from hwchase17/harrison/add_llms
add llm objects
- Loading branch information
Showing
14 changed files
with
199 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Wrappers on top of large language models.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
"""Base interface for large language models to expose.""" | ||
from abc import ABC, abstractmethod | ||
from typing import List, Optional | ||
|
||
|
||
class LLM(ABC):
    """Abstract interface for LLM wrappers: take a prompt, return a string.

    Concrete wrappers subclass this and implement ``__call__`` so that an
    instance can be invoked directly, e.g. ``llm("Say foo:")``.
    """

    @abstractmethod
    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Run the LLM on the given prompt and return the generated text.

        Args:
            prompt: The text to send to the model.
            stop: Optional list of stop sequences; ``None`` means no stop
                sequences are supplied.

        Returns:
            The text generated by the model.
        """
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
"""Wrapper around Cohere APIs.""" | ||
import os | ||
from typing import Any, Dict, List, Optional | ||
|
||
from pydantic import BaseModel, Extra, root_validator | ||
|
||
from langchain.llms.base import LLM | ||
|
||
|
||
def remove_stop_tokens(text: str, stop: List[str]) -> str:
    """Remove a stop token from the end of ``text``, should one occur there.

    The tokens are checked in the order given and only the first one that
    ``text`` ends with is removed; at most one token is stripped.

    Args:
        text: The generated text to clean up.
        stop: Candidate stop tokens. Empty strings are ignored.

    Returns:
        ``text`` without the matched trailing stop token, or ``text``
        unchanged when it ends with none of them.
    """
    for token in stop:
        # Every string "ends with" the empty string, and ``text[:-0]`` is
        # ``text[:0]`` (i.e. ""), so an empty token would wipe the whole
        # text. Skip empty tokens instead.
        if token and text.endswith(token):
            return text[: -len(token)]
    return text
|
||
|
||
class Cohere(BaseModel, LLM):
    """Wrapper around Cohere large language models.

    The ``cohere`` package must be installed and the ``COHERE_API_KEY``
    environment variable must be set; both are checked by the root
    validator, which also builds the client.
    """

    # Populated by the root validator with a ``cohere.Client`` instance.
    client: Any
    # The fields below are forwarded unchanged to ``client.generate``.
    model: str = "gptd-instruct-tft"
    max_tokens: int = 256
    temperature: float = 0.6
    k: int = 0
    # Declared as floats so that pydantic does not coerce fractional values
    # (e.g. ``p=0.75``) to integers.
    p: float = 1.0
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0

    class Config:
        """Configuration for this pydantic object."""

        # Reject unknown keyword arguments instead of silently ignoring them.
        extra = Extra.forbid

    @root_validator()
    def template_is_valid(cls, values: Dict) -> Dict:
        """Validate that the API key and the cohere package are available.

        On success, stores a ``cohere.Client`` built from the
        ``COHERE_API_KEY`` environment variable under ``values["client"]``.

        Raises:
            ValueError: If ``COHERE_API_KEY`` is not set or the ``cohere``
                package cannot be imported.
        """
        if "COHERE_API_KEY" not in os.environ:
            raise ValueError(
                "Did not find Cohere API key, please add an environment variable"
                " `COHERE_API_KEY` which contains it."
            )
        try:
            # Imported lazily so the dependency is only needed when this
            # wrapper is actually instantiated.
            import cohere
        except ImportError as exc:
            raise ValueError(
                "Could not import cohere python package. "
                "Please install it with `pip install cohere`."
            ) from exc
        values["client"] = cohere.Client(os.environ["COHERE_API_KEY"])
        return values

    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call out to Cohere's generate endpoint.

        Args:
            prompt: The prompt to pass to the model.
            stop: Optional list of stop sequences, forwarded as
                ``stop_sequences``.

        Returns:
            The text of the first generation, with a trailing stop token
            removed when ``stop`` is given.
        """
        response = self.client.generate(
            model=self.model,
            prompt=prompt,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            k=self.k,
            p=self.p,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
            stop_sequences=stop,
        )
        text = response.generations[0].text
        # If stop tokens are provided, Cohere's endpoint returns them.
        # In order to make this consistent with other endpoints, we strip them.
        if stop is not None:
            text = remove_stop_tokens(text, stop)
        return text
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
"""Wrapper around OpenAI APIs.""" | ||
import os | ||
from typing import Any, Dict, List, Mapping, Optional | ||
|
||
from pydantic import BaseModel, Extra, root_validator | ||
|
||
from langchain.llms.base import LLM | ||
|
||
|
||
class OpenAI(BaseModel, LLM):
    """Wrapper around OpenAI large language models.

    The ``openai`` package must be installed and the ``OPENAI_API_KEY``
    environment variable must be set; both are checked by
    ``validate_environment``, which also sets the client.
    """

    # Populated by ``validate_environment`` with ``openai.Completion``.
    client: Any
    model_name: str = "text-davinci-002"
    temperature: float = 0.7
    max_tokens: int = 256
    # Declared as floats so that pydantic does not coerce fractional values
    # (e.g. ``top_p=0.9``) to integers.
    top_p: float = 1.0
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    n: int = 1
    best_of: int = 1

    class Config:
        """Configuration for this pydantic object."""

        # Reject unknown keyword arguments instead of silently ignoring them.
        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key and the openai package are available.

        On success, stores ``openai.Completion`` under ``values["client"]``.

        Raises:
            ValueError: If ``OPENAI_API_KEY`` is not set or the ``openai``
                package cannot be imported.
        """
        if "OPENAI_API_KEY" not in os.environ:
            raise ValueError(
                "Did not find OpenAI API key, please add an environment variable"
                " `OPENAI_API_KEY` which contains it."
            )
        try:
            # Imported lazily so the dependency is only needed when this
            # wrapper is actually instantiated.
            import openai
        except ImportError as exc:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            ) from exc
        values["client"] = openai.Completion
        return values

    @property
    def default_params(self) -> Mapping[str, Any]:
        """Get the default parameters for calling OpenAI API.

        Returns:
            A mapping of the sampling fields of this instance, suitable for
            passing as keyword arguments to ``client.create``.
        """
        return {
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "n": self.n,
            "best_of": self.best_of,
        }

    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call out to OpenAI's create endpoint.

        Args:
            prompt: The prompt to pass to the model.
            stop: Optional list of stop sequences, forwarded as ``stop``.

        Returns:
            The text of the first choice in the response.
        """
        response = self.client.create(
            model=self.model_name, prompt=prompt, stop=stop, **self.default_params
        )
        return response["choices"][0]["text"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,10 @@ | ||
-e . | ||
pytest | ||
pytest-dotenv | ||
black | ||
isort | ||
mypy | ||
flake8 | ||
flake8-docstrings | ||
cohere | ||
openai |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""All tests for this package.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""All integration tests (tests that call out to an external API).""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""All integration tests for LLM objects.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
"""Test Cohere API wrapper.""" | ||
|
||
from langchain.llms.cohere import Cohere | ||
|
||
|
||
def test_cohere_call() -> None:
    """Test valid call to cohere.

    Integration test: it instantiates the real wrapper and calls it, so it
    needs the ``cohere`` package and ``COHERE_API_KEY`` in the environment.
    """
    llm = Cohere(max_tokens=10)
    output = llm("Say foo:")
    # Only the return type is checked; the generated text itself is not pinned.
    assert isinstance(output, str)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
"""Test OpenAI API wrapper.""" | ||
|
||
from langchain.llms.openai import OpenAI | ||
|
||
|
||
def test_openai_call() -> None:
    """Test valid call to openai.

    Integration test: it instantiates the real wrapper and calls it, so it
    needs the ``openai`` package and ``OPENAI_API_KEY`` in the environment.
    """
    llm = OpenAI(max_tokens=10)
    output = llm("Say foo:")
    # Only the return type is checked; the generated text itself is not pinned.
    assert isinstance(output, str)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""All unit tests (lightweight tests).""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""All unit tests for LLM objects.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
"""Test helper functions for Cohere API.""" | ||
|
||
from langchain.llms.cohere import remove_stop_tokens | ||
|
||
|
||
def test_remove_stop_tokens() -> None:
    """Test removing stop tokens when they occur."""
    text = "foo bar baz"
    # "moo" does not match; "baz" is at the end and must be stripped,
    # leaving the preceding space in place.
    output = remove_stop_tokens(text, ["moo", "baz"])
    assert output == "foo bar "
|
||
|
||
def test_remove_stop_tokens_none() -> None:
    """Test removing stop tokens when they do not occur."""
    text = "foo bar baz"
    # No candidate token is at the end, so the text must come back unchanged.
    output = remove_stop_tokens(text, ["moo"])
    assert output == "foo bar baz"