Skip to content

Commit

Permalink
Merge pull request #2 from octoml/octoai
Browse files Browse the repository at this point in the history
add embeddings using instructor large endpoint
  • Loading branch information
AI-Bassem authored Jun 12, 2023
2 parents c9c31ce + 31eee4e commit dcedb0a
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 22 deletions.
82 changes: 82 additions & 0 deletions langchain/embeddings/octoai_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Module providing a wrapper around OctoAI Compute Service embedding models."""

from typing import Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from octoai import client

DEFAULT_EMBED_INSTRUCTION = "Represent this input: "
DEFAULT_QUERY_INSTRUCTION = "Represent the question for retrieving similar documents: "


class OctoAIEmbeddings(BaseModel, Embeddings):
"""
Wrapper around OctoAI Compute Service embedding models.
The environment variable ``OCTOAI_API_TOKEN`` should be set with your API token, or it can be passed
as a named parameter to the constructor.
"""
endpoint_url: Optional[str] = Field(
None, description="Endpoint URL to use.")
model_kwargs: Optional[dict] = Field(
None, description="Keyword arguments to pass to the model.")
octoai_api_token: Optional[str] = Field(
None, description="OCTOAI API Token")
embed_instruction: str = Field(
DEFAULT_EMBED_INSTRUCTION, description="Instruction to use for embedding documents.")
query_instruction: str = Field(
DEFAULT_QUERY_INSTRUCTION, description="Instruction to use for embedding query.")

class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid

@root_validator(allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Ensure that the API key and python package exist in environment."""
values["octoai_api_token"] = get_from_dict_or_env(
values, "octoai_api_token", "OCTOAI_API_TOKEN")
values["endpoint_url"] = get_from_dict_or_env(
values, "endpoint_url", "ENDPOINT_URL")
return values

@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Return the identifying parameters."""
return {"endpoint_url": self.endpoint_url, "model_kwargs": self.model_kwargs or {}}

def _compute_embeddings(self, texts: List[str], instruction: str) -> List[List[float]]:
"""Common functionality for compute embeddings using a OctoAI instruct model."""
embeddings = []
octoai_client = client.Client(token=self.octoai_api_token)

for text in texts:
parameter_payload = {
"sentence": str([text]),# for item in text]),
"instruction": str([instruction]),# for item in text]),
"parameters": self.model_kwargs or {}
}

try:
resp_json = octoai_client.infer(
self.endpoint_url, parameter_payload)
embedding = resp_json["embeddings"]
except Exception as e:
raise ValueError(
f"Error raised by the inference endpoint: {e}") from e

embeddings.append(embedding)

return embeddings

def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute document embeddings using an OctoAI instruct model."""
texts = list(map(lambda x: x.replace("\n", " "), texts))
return self._compute_embeddings(texts, self.embed_instruction)

def embed_query(self, text: str) -> List[float]:
"""Compute query embedding using an OctoAI instruct model."""
text = text.replace("\n", " ")
return self._compute_embeddings([text], self.embed_instruction)

39 changes: 21 additions & 18 deletions langchain/llms/octoai_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,20 @@ class OctoAIEndpoint(LLM):
Example:
.. code-block:: python
from langchain.llms import OctoAIEndpoint
endpoint_url = (
"https://endpoint_name-account_id.octoai.cloud"
)
endpoint = OctoAIEndpoint(
endpoint_url=endpoint_url,
octoai_api_token="octoai-api-key"
from langchain.llms.octoai_endpoint import OctoAIEndpoint
OctoAIEndpoint(
octoai_api_token="octoai-api-key",
endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
model_kwargs={
"max_new_tokens": 200,
"temperature": 0.75,
"top_p": 0.95,
"repetition_penalty": 1,
"seed": None,
"stop": [],
},
)
"""

endpoint_url: Optional[str] = None
Expand All @@ -45,7 +51,7 @@ class Config:

extra = Extra.forbid

@root_validator()
@root_validator(allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
octoai_api_token = get_from_dict_or_env(
Expand Down Expand Up @@ -90,26 +96,23 @@ def _call(
"""
_model_kwargs = self.model_kwargs or {}

# payload json
# Prepare the payload JSON
parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}

# HTTP headers for authorization
headers = {
"Authorization": f"Bearer {self.octoai_api_token}",
"Content-Type": "application/json",
}

# send request using octaoai sdk
try:
# Initialize the OctoAI client
octoai_client = client.Client(token=self.octoai_api_token)

# Send the request using the OctoAI client
resp_json = octoai_client.infer(self.endpoint_url, parameter_payload)
text = resp_json["generated_text"]

except Exception as e:
raise ValueError(f"Error raised by inference endpoint: {e}") from e
# Handle any errors raised by the inference endpoint
raise ValueError(f"Error raised by the inference endpoint: {e}") from e

if stop is not None:
# stop tokens when making calls to octoai.
# Apply stop tokens when making calls to OctoAI
text = enforce_stop_tokens(text, stop)

return text
23 changes: 23 additions & 0 deletions tests/integration_tests/embeddings/test_octoai_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Test octoai embeddings."""

from langchain.embeddings.octoai_embeddings import (
OctoAIEmbeddings,
)


def test_octoai_embedding_documents() -> None:
"""Test octoai embeddings."""
documents = ["foo bar"]
embedding = OctoAIEmbeddings()
output = embedding.embed_documents(documents)
assert len(output) == 1
assert len(output[0]) == 768


def test_octoai_embedding_query() -> None:
"""Test octoai embeddings."""
document = "foo bar"
embedding = OctoAIEmbeddings()
output = embedding.embed_query(document)
assert len(output) == 1
assert len(output[0]) == 768
10 changes: 6 additions & 4 deletions tests/integration_tests/llms/test_octoai_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@

from tests.integration_tests.llms.utils import assert_llm_equality


def test_octoai_endpoint_text_generation() -> None:
"""Test valid call to OctoAI text generation model."""
llm = OctoAIEndpoint(
endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
model_kwargs={
"max_new_tokens": 512,
"max_new_tokens": 200,
"temperature": 0.75,
"top_p": 0.95,
"repetition_penalty": 1,
Expand All @@ -32,8 +33,9 @@ def test_octoai_endpoint_text_generation() -> None:
def test_octoai_endpoint_call_error() -> None:
"""Test valid call to OctoAI that errors."""
llm = OctoAIEndpoint(
endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
model_kwargs={"max_new_tokens": -1})
endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
model_kwargs={"max_new_tokens": -1},
)
with pytest.raises(ValueError):
llm("Which state is Los Angeles in?")

Expand All @@ -43,7 +45,7 @@ def test_saving_loading_endpoint_llm(tmp_path: Path) -> None:
llm = OctoAIEndpoint(
endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
model_kwargs={
"max_new_tokens": 512,
"max_new_tokens": 200,
"temperature": 0.75,
"top_p": 0.95,
"repetition_penalty": 1,
Expand Down

0 comments on commit dcedb0a

Please sign in to comment.