Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add llm objects #4

Merged
merged 2 commits into from
Oct 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.PHONY: format lint tests
.PHONY: format lint tests integration_tests

format:
black .
Expand All @@ -11,4 +11,7 @@ lint:
mypy .

tests:
pytest tests
pytest tests/unit_tests

integration_tests:
pytest tests/integration_tests
1 change: 1 addition & 0 deletions langchain/llms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Wrappers on top of large language models."""
11 changes: 11 additions & 0 deletions langchain/llms/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Base interface for large language models to expose."""
from abc import ABC, abstractmethod
from typing import List, Optional


class LLM(ABC):
    """Interface every LLM wrapper implements: prompt in, string out."""

    @abstractmethod
    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Run the LLM on the given prompt and return its text output.

        Args:
            prompt: Text to send to the model.
            stop: Optional list of stop sequences to hand to the model.

        Returns:
            The text produced by the model.
        """
72 changes: 72 additions & 0 deletions langchain/llms/cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Wrapper around Cohere APIs."""
import os
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Extra, root_validator

from langchain.llms.base import LLM


def remove_stop_tokens(text: str, stop: List[str]) -> str:
    """Strip a trailing stop token from ``text``, if one is present.

    The stop tokens are tried in order and only the first one that ``text``
    ends with is removed, so at most one token is stripped per call.

    Args:
        text: Generated text that may end with a stop token.
        stop: Candidate stop tokens.

    Returns:
        ``text`` without the matching trailing stop token, or ``text``
        unchanged when no stop token matches.
    """
    for token in stop:
        # An empty token "matches" the end of every string, and
        # ``text[:-0]`` would then discard the whole text, so skip it.
        if token and text.endswith(token):
            return text[: -len(token)]
    return text


class Cohere(BaseModel, LLM):
    """Wrapper around Cohere large language models.

    Instantiating this class requires the ``cohere`` python package to be
    importable and the ``COHERE_API_KEY`` environment variable to be set;
    both are checked by the root validator.
    """

    # Cohere client; (re)created by the root validator on instantiation.
    client: Any
    model: str = "gptd-instruct-tft"
    max_tokens: int = 256
    temperature: float = 0.6
    k: int = 0
    # Declared as floats: with an ``int`` annotation pydantic would truncate
    # or reject fractional values such as ``p=0.75``.
    p: float = 1
    frequency_penalty: float = 0
    presence_penalty: float = 0

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def template_is_valid(cls, values: Dict) -> Dict:
        """Validate that the API key and python package exist in environment.

        Args:
            values: Field values collected by pydantic.

        Returns:
            ``values`` with ``client`` set to a ``cohere.Client``.

        Raises:
            ValueError: If ``COHERE_API_KEY`` is not set or the ``cohere``
                package cannot be imported.
        """
        if "COHERE_API_KEY" not in os.environ:
            raise ValueError(
                "Did not find Cohere API key, please add an environment variable"
                " `COHERE_API_KEY` which contains it."
            )
        try:
            import cohere
        except ImportError as err:
            raise ValueError(
                "Could not import cohere python package. "
                "Please install it with `pip install cohere`."
            ) from err
        values["client"] = cohere.Client(os.environ["COHERE_API_KEY"])
        return values

    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call out to Cohere's generate endpoint.

        Args:
            prompt: The prompt to pass to the model.
            stop: Optional stop sequences, forwarded as ``stop_sequences``.

        Returns:
            The text of the first generation, with a trailing stop token
            removed when ``stop`` is given.
        """
        response = self.client.generate(
            model=self.model,
            prompt=prompt,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            k=self.k,
            p=self.p,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
            stop_sequences=stop,
        )
        text = response.generations[0].text
        # If stop tokens are provided, Cohere's endpoint returns them.
        # In order to make this consistent with other endpoints, we strip them.
        if stop is not None:
            text = remove_stop_tokens(text, stop)
        return text
65 changes: 65 additions & 0 deletions langchain/llms/openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Wrapper around OpenAI APIs."""
import os
from typing import Any, Dict, List, Mapping, Optional

from pydantic import BaseModel, Extra, root_validator

from langchain.llms.base import LLM


class OpenAI(BaseModel, LLM):
    """Wrapper around OpenAI large language models.

    Instantiating this class requires the ``openai`` python package to be
    importable and the ``OPENAI_API_KEY`` environment variable to be set;
    both are checked by the root validator.
    """

    # Set to ``openai.Completion`` by the root validator on instantiation.
    client: Any
    model_name: str = "text-davinci-002"
    temperature: float = 0.7
    max_tokens: int = 256
    # Declared as floats: with an ``int`` annotation pydantic would truncate
    # or reject fractional values such as ``top_p=0.9``.
    top_p: float = 1
    frequency_penalty: float = 0
    presence_penalty: float = 0
    n: int = 1
    best_of: int = 1

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key and python package exist in environment.

        Args:
            values: Field values collected by pydantic.

        Returns:
            ``values`` with ``client`` set to ``openai.Completion``.

        Raises:
            ValueError: If ``OPENAI_API_KEY`` is not set or the ``openai``
                package cannot be imported.
        """
        if "OPENAI_API_KEY" not in os.environ:
            raise ValueError(
                "Did not find OpenAI API key, please add an environment variable"
                " `OPENAI_API_KEY` which contains it."
            )
        try:
            import openai
        except ImportError as err:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            ) from err
        values["client"] = openai.Completion
        return values

    @property
    def default_params(self) -> Mapping[str, Any]:
        """Get the default parameters for calling OpenAI API."""
        return {
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "n": self.n,
            "best_of": self.best_of,
        }

    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call out to OpenAI's create endpoint.

        Args:
            prompt: The prompt to pass to the model.
            stop: Optional stop sequences, forwarded as ``stop``.

        Returns:
            The text of the first choice in the response.
        """
        response = self.client.create(
            model=self.model_name, prompt=prompt, stop=stop, **self.default_params
        )
        return response["choices"][0]["text"]
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
-e .
pytest
pytest-dotenv
black
isort
mypy
flake8
flake8-docstrings
cohere
openai
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""All tests for this package."""
1 change: 1 addition & 0 deletions tests/integration_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""All integration tests (tests that call out to an external API)."""
1 change: 1 addition & 0 deletions tests/integration_tests/llms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""All integration tests for LLM objects."""
10 changes: 10 additions & 0 deletions tests/integration_tests/llms/test_cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Test Cohere API wrapper."""

from langchain.llms.cohere import Cohere


def test_cohere_call() -> None:
    """Check that a call through the Cohere wrapper yields a string."""
    result = Cohere(max_tokens=10)("Say foo:")
    assert isinstance(result, str)
10 changes: 10 additions & 0 deletions tests/integration_tests/llms/test_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Test OpenAI API wrapper."""

from langchain.llms.openai import OpenAI


def test_openai_call() -> None:
    """Test valid call to openai."""
    llm = OpenAI(max_tokens=10)
    output = llm("Say foo:")
    assert isinstance(output, str)
1 change: 1 addition & 0 deletions tests/unit_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""All unit tests (lightweight tests)."""
1 change: 1 addition & 0 deletions tests/unit_tests/llms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""All unit tests for LLM objects."""
17 changes: 17 additions & 0 deletions tests/unit_tests/llms/test_cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Test helper functions for Cohere API."""

from langchain.llms.cohere import remove_stop_tokens


def test_remove_stop_tokens() -> None:
    """A stop token found at the end of the text is stripped."""
    stripped = remove_stop_tokens("foo bar baz", ["moo", "baz"])
    assert stripped == "foo bar "


def test_remove_stop_tokens_none() -> None:
    """Text is returned untouched when no stop token occurs at its end."""
    original = "foo bar baz"
    assert remove_stop_tokens(original, ["moo"]) == "foo bar baz"