Merge pull request #1 from octoml/octoai
OctoAI Endpoint
Showing 5 changed files with 302 additions and 0 deletions.
@@ -0,0 +1,126 @@
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## OctoAI Compute Service\n",
    "This example goes over how to use LangChain to interact with `OctoAI` [LLM endpoints](https://octoai.cloud/templates).\n",
    "## Environment setup\n",
    "\n",
    "To run our example app, there are four simple steps to take:\n",
    "\n",
    "1. Clone the MPT-7B demo template to your OctoAI account by visiting <https://octoai.cloud/templates/mpt-7b-demo>, then clicking \"Clone Template.\"\n",
    "   1. If you want to use a different LLM model, you can also containerize the model and make a custom OctoAI endpoint yourself, by following [Build a Container from Python](doc:create-custom-endpoints-from-python-code) and [Create a Custom Endpoint from a Container](doc:create-custom-endpoints-from-a-container).\n",
    "\n",
    "2. Paste your Endpoint URL in the code cell below.\n",
    "\n",
    "3. Get an API Token from [your OctoAI account page](https://octoai.cloud/settings).\n",
    "\n",
    "4. Paste your API Token in the code cell below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"OCTOAI_API_TOKEN\"] = \"OCTOAI_API_TOKEN\"\n",
    "os.environ[\"ENDPOINT_URL\"] = \"https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.llms.octoai_endpoint import OctoAIEndpoint\n",
    "from langchain import PromptTemplate, LLMChain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"\"\"Below is an instruction that describes a task. Write a response that appropriately completes the request.\\n Instruction:\\n{question}\\n Response: \"\"\"\n",
    "prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = OctoAIEndpoint(\n",
    "    model_kwargs={\n",
    "        \"max_new_tokens\": 200,\n",
    "        \"temperature\": 0.75,\n",
    "        \"top_p\": 0.95,\n",
    "        \"repetition_penalty\": 1,\n",
    "        \"seed\": None,\n",
    "        \"stop\": [],\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\nLeonardo da Vinci was an Italian polymath and painter regarded by many as one of the greatest painters of all time. He is best known for his masterpieces including Mona Lisa, The Last Supper, and The Virgin of the Rocks. He was a draftsman, sculptor, architect, and one of the most important figures in the history of science. Da Vinci flew gliders, experimented with water turbines and windmills, and invented the catapult and a joystick-type human-powered aircraft control. He may have pioneered helicopters. As a scholar, he was interested in anatomy, geology, botany, engineering, mathematics, and astronomy.\\nOther painters and patrons claimed to be more talented, but Leonardo da Vinci was an incredibly productive artist, sculptor, engineer, anatomist, and scientist.'"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "question = \"Who was Leonardo da Vinci?\"\n",
    "\n",
    "llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
    "\n",
    "llm_chain.run(question)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "langchain",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "97697b63fdcee0a640856f91cb41326ad601964008c341809e43189d1cab1047"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
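
Outside the notebook, the same flow condenses into a short standalone script. The sketch below only restates the notebook cells above; the token and endpoint URL are placeholders to replace with your own values, and the sampling parameters mirror the notebook:

import os

from langchain import LLMChain, PromptTemplate
from langchain.llms.octoai_endpoint import OctoAIEndpoint

# Placeholder credentials: substitute your OctoAI token and cloned endpoint URL.
os.environ["OCTOAI_API_TOKEN"] = "OCTOAI_API_TOKEN"
os.environ["ENDPOINT_URL"] = "https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate"

# Instruction-style prompt matching the MPT-7B demo template.
template = """Below is an instruction that describes a task. Write a response that appropriately completes the request.\n Instruction:\n{question}\n Response: """
prompt = PromptTemplate(template=template, input_variables=["question"])

# model_kwargs are forwarded to the endpoint as generation parameters.
llm = OctoAIEndpoint(
    model_kwargs={
        "max_new_tokens": 200,
        "temperature": 0.75,
        "top_p": 0.95,
    },
)

llm_chain = LLMChain(prompt=prompt, llm=llm)
print(llm_chain.run("Who was Leonardo da Vinci?"))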
@@ -0,0 +1,115 @@
"""Wrapper around OctoAI APIs."""
from typing import Any, Dict, List, Mapping, Optional

from pydantic import Extra, root_validator

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env

from octoai import client


class OctoAIEndpoint(LLM):
    """Wrapper around OctoAI Inference Endpoints.

    OctoAIEndpoint is a class to interact with OctoAI Compute Service
    large language model endpoints.

    To use, you should have the ``octoai`` python package installed, and the
    environment variable ``OCTOAI_API_TOKEN`` set with your API token, or pass
    it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain.llms import OctoAIEndpoint

            endpoint_url = "https://endpoint_name-account_id.octoai.cloud"
            endpoint = OctoAIEndpoint(
                endpoint_url=endpoint_url,
                octoai_api_token="octoai-api-key",
            )
    """

    endpoint_url: Optional[str] = None
    """Endpoint URL to use."""

    model_kwargs: Optional[dict] = None
    """Keyword arguments to pass to the model."""

    octoai_api_token: Optional[str] = None
    """OctoAI API Token."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API token and endpoint URL exist in the environment."""
        octoai_api_token = get_from_dict_or_env(
            values, "octoai_api_token", "OCTOAI_API_TOKEN"
        )
        values["endpoint_url"] = get_from_dict_or_env(
            values, "endpoint_url", "ENDPOINT_URL"
        )

        values["octoai_api_token"] = octoai_api_token
        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            "endpoint_url": self.endpoint_url,
            "model_kwargs": _model_kwargs,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "octoai_endpoint"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Call out to OctoAI's inference endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.
        """
        _model_kwargs = self.model_kwargs or {}

        # Build the JSON payload: the prompt plus any model parameters.
        parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}

        # Send the request using the octoai SDK, which handles
        # authorization with the API token.
        try:
            octoai_client = client.Client(token=self.octoai_api_token)
            resp_json = octoai_client.infer(self.endpoint_url, parameter_payload)
            text = resp_json["generated_text"]
        except Exception as e:
            raise ValueError(f"Error raised by inference endpoint: {e}") from e

        if stop is not None:
            # Truncate the response at the first occurrence of any stop token.
            text = enforce_stop_tokens(text, stop)

        return text
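
For quick reference, the wrapper can also be exercised directly, without a chain. A minimal sketch, assuming a placeholder endpoint URL and token (per validate_environment, both may instead come from the ENDPOINT_URL and OCTOAI_API_TOKEN environment variables):

from langchain.llms.octoai_endpoint import OctoAIEndpoint

# Placeholder values; swap in your own endpoint URL and API token.
llm = OctoAIEndpoint(
    endpoint_url="https://endpoint_name-account_id.octoai.cloud/generate",
    octoai_api_token="octoai-api-key",
    model_kwargs={"max_new_tokens": 128},
)

# `stop` is enforced client-side by enforce_stop_tokens, so the returned
# text is truncated at the first matching stop sequence.
print(llm("Name three Italian painters.", stop=["\n\n"]))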
@@ -0,0 +1,56 @@
"""Test OctoAI API wrapper."""

from pathlib import Path

import pytest

from langchain.llms.loading import load_llm
from langchain.llms.octoai_endpoint import OctoAIEndpoint

from tests.integration_tests.llms.utils import assert_llm_equality


def test_octoai_endpoint_text_generation() -> None:
    """Test valid call to OctoAI text generation model."""
    llm = OctoAIEndpoint(
        endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
        model_kwargs={
            "max_new_tokens": 512,
            "temperature": 0.75,
            "top_p": 0.95,
            "repetition_penalty": 1,
            "seed": None,
            "stop": [],
        },
    )

    output = llm("Which state is Los Angeles in?")
    print(output)
    assert isinstance(output, str)


def test_octoai_endpoint_call_error() -> None:
    """Test that an invalid call to OctoAI raises a ValueError."""
    llm = OctoAIEndpoint(
        endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
        model_kwargs={"max_new_tokens": -1},
    )
    with pytest.raises(ValueError):
        llm("Which state is Los Angeles in?")


def test_saving_loading_endpoint_llm(tmp_path: Path) -> None:
    """Test saving/loading an OctoAI endpoint LLM."""
    llm = OctoAIEndpoint(
        endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
        model_kwargs={
            "max_new_tokens": 512,
            "temperature": 0.75,
            "top_p": 0.95,
            "repetition_penalty": 1,
            "seed": None,
            "stop": [],
        },
    )
    llm.save(file_path=tmp_path / "octoai.yaml")
    loaded_llm = load_llm(tmp_path / "octoai.yaml")
    assert_llm_equality(llm, loaded_llm)
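
Outside pytest, the save/load round trip exercised by the last test looks roughly like this; the file path is arbitrary and the endpoint URL is the same demo placeholder the tests use:

from pathlib import Path

from langchain.llms.loading import load_llm
from langchain.llms.octoai_endpoint import OctoAIEndpoint

llm = OctoAIEndpoint(
    endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
    model_kwargs={"max_new_tokens": 512},
)

# Serialize the LLM configuration to YAML, then reconstruct it.
config_path = Path("octoai.yaml")
llm.save(file_path=config_path)
loaded_llm = load_llm(config_path)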