From 56f019279d75a1bb863e161f423d0fab3a1b0324 Mon Sep 17 00:00:00 2001 From: Cynthia Yang Date: Fri, 8 Sep 2023 16:33:49 -0700 Subject: [PATCH 01/14] Refactor Fireworks api and remove FireworksChat --- libs/langchain/langchain/llms/__init__.py | 2 +- libs/langchain/langchain/llms/fireworks.py | 384 +----------------- .../integration_tests/llms/test_fireworks.py | 122 +----- 3 files changed, 26 insertions(+), 482 deletions(-) diff --git a/libs/langchain/langchain/llms/__init__.py b/libs/langchain/langchain/llms/__init__.py index 8e835ea0a9108..041ba8e7c3c0a 100644 --- a/libs/langchain/langchain/llms/__init__.py +++ b/libs/langchain/langchain/llms/__init__.py @@ -44,7 +44,7 @@ from langchain.llms.deepsparse import DeepSparse from langchain.llms.edenai import EdenAI from langchain.llms.fake import FakeListLLM -from langchain.llms.fireworks import Fireworks, FireworksChat +from langchain.llms.fireworks import Fireworks from langchain.llms.forefrontai import ForefrontAI from langchain.llms.google_palm import GooglePalm from langchain.llms.gooseai import GooseAI diff --git a/libs/langchain/langchain/llms/fireworks.py b/libs/langchain/langchain/llms/fireworks.py index 2b4b5f0ff95dd..00e390b920a2e 100644 --- a/libs/langchain/langchain/llms/fireworks.py +++ b/libs/langchain/langchain/llms/fireworks.py @@ -1,377 +1,35 @@ -"""Wrapper around Fireworks APIs""" -import json -import logging -from typing import ( - Any, - Dict, - List, - Optional, - Set, - Tuple, - Union, -) +import os +from typing import Any, Optional -import requests +from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.llms.base import LLM +import openai -from langchain.callbacks.manager import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) -from langchain.llms.base import BaseLLM -from langchain.pydantic_v1 import Field, root_validator -from langchain.schema import Generation, LLMResult -from langchain.utils import get_from_dict_or_env -logger = logging.getLogger(__name__) +class Fireworks(LLM): + """Fireworks models.""" - -class BaseFireworks(BaseLLM): - """Wrapper around Fireworks large language models.""" - - model_id: str = Field("accounts/fireworks/models/llama-v2-7b-chat", alias="model") - """Model name to use.""" - temperature: float = 0.7 - """What sampling temperature to use.""" - max_tokens: int = 512 - """The maximum number of tokens to generate in the completion. - -1 returns as many tokens as possible given the prompt and - the models maximal context size.""" - top_p: float = 1 - """Total probability mass of tokens to consider at each step.""" - fireworks_api_key: Optional[str] = None - """Api key to use fireworks API""" - batch_size: int = 20 - """Batch size to use when passing multiple documents to generate.""" - request_timeout: Optional[Union[float, Tuple[float, float]]] = None - """Timeout for requests to Fireworks completion API. 
Default is 600 seconds.""" - max_retries: int = 6 - """Maximum number of retries to make when generating.""" - - @property - def lc_secrets(self) -> Dict[str, str]: - return {"fireworks_api_key": "FIREWORKS_API_KEY"} - - @property - def lc_serializable(self) -> bool: - return True - - def __new__(cls, **data: Any) -> Any: - """Initialize the Fireworks object.""" - data.get("model_id", "") - return super().__new__(cls) - - class Config: - """Configuration for this pydantic object.""" - - allow_population_by_field_name = True - - @root_validator() - def validate_environment(cls, values: Dict) -> Dict: - """Validate that api key and python package exists in environment.""" - values["fireworks_api_key"] = get_from_dict_or_env( - values, "fireworks_api_key", "FIREWORKS_API_KEY" - ) - return values - - def _generate( - self, - prompts: List[str], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> LLMResult: - """Call out to Fireworks endpoint with k unique prompts. - Args: - prompts: The prompts to pass into the model. - stop: Optional list of stop words to use when generating. - Returns: - The full LLM output. - """ - params = {"model": self.model_id} - params = {**params, **kwargs} - sub_prompts = self.get_batch_prompts(params, prompts, stop) - choices = [] - token_usage: Dict[str, int] = {} - _keys = {"completion_tokens", "prompt_tokens", "total_tokens"} - for _prompts in sub_prompts: - response = completion_with_retry(self, prompt=prompts, **params) - choices.extend(response) - update_token_usage(_keys, response, token_usage) - - return self.create_llm_result(choices, prompts, token_usage) - - async def _agenerate( - self, - prompts: List[str], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> LLMResult: - """Call out to Fireworks endpoint async with k unique prompts.""" - params = {"model": self.model_id} - params = {**params, **kwargs} - sub_prompts = self.get_batch_prompts(params, prompts, stop) - choices = [] - token_usage: Dict[str, int] = {} - _keys = {"completion_tokens", "prompt_tokens", "total_tokens"} - for _prompts in sub_prompts: - response = await acompletion_with_retry(self, prompt=_prompts, **params) - choices.extend(response) - update_token_usage(_keys, response, token_usage) - - return self.create_llm_result(choices, prompts, token_usage) - - def get_batch_prompts( - self, - params: Dict[str, Any], - prompts: List[str], - stop: Optional[List[str]] = None, - ) -> List[List[str]]: - """Get the sub prompts for llm call.""" - if stop is not None: - if "stop" in params: - raise ValueError("`stop` found in both the input and default params.") - - sub_prompts = [ - prompts[i : i + self.batch_size] - for i in range(0, len(prompts), self.batch_size) - ] - return sub_prompts - - def create_llm_result( - self, choices: Any, prompts: List[str], token_usage: Dict[str, int] - ) -> LLMResult: - """Create the LLMResult from the choices and prompts.""" - generations = [] - - for i, _ in enumerate(prompts): - sub_choices = choices[i : (i + 1)] - generations.append( - [ - Generation( - text=choice, - ) - for choice in sub_choices - ] - ) - llm_output = {"token_usage": token_usage, "model_id": self.model_id} - return LLMResult(generations=generations, llm_output=llm_output) + model = "accounts/fireworks/models/llama-v2-13b-chat" + model_kwargs: Optional[dict] = {"temperature": 0.7, "max_tokens": 512, "top_p": 1} + fireworks_api_url: Optional[str] = 
"https://api.fireworks.ai/inference/v1" + fireworks_api_key: Optional[str] = os.environ.get("FIREWORKS_API_KEY") @property def _llm_type(self) -> str: """Return type of llm.""" return "fireworks" - -class FireworksChat(BaseLLM): - """Wrapper around Fireworks Chat large language models. - To use, you should have the ``fireworksai`` python package installed, and the - environment variable ``FIREWORKS_API_KEY`` set with your API key. - Any parameters that are valid to be passed to the fireworks.create - call can be passed in, even if not explicitly saved on this class. - Example: - .. code-block:: python - from langchain.llms import FireworksChat - fireworkschat = FireworksChat(model_id=""llama-v2-13b-chat"") - """ - - model_id: str = "accounts/fireworks/models/llama-v2-7b-chat" - """Model name to use.""" - temperature: float = 0.7 - """What sampling temperature to use.""" - max_tokens: int = 512 - """The maximum number of tokens to generate in the completion. - -1 returns as many tokens as possible given the prompt and - the models maximal context size.""" - top_p: float = 1 - """Total probability mass of tokens to consider at each step.""" - fireworks_api_key: Optional[str] = None - max_retries: int = 6 - request_timeout: Optional[Union[float, Tuple[float, float]]] = None - """Timeout for requests to Fireworks completion API. Default is 600 seconds.""" - """Maximum number of retries to make when generating.""" - prefix_messages: List = Field(default_factory=list) - """Series of messages for Chat input.""" - - @root_validator() - def validate_environment(cls, values: Dict) -> Dict: - """Validate that api key and python package exists in environment""" - values["fireworks_api_key"] = get_from_dict_or_env( - values, "fireworks_api_key", "FIREWORKS_API_KEY" - ) - return values - - def _get_chat_params( - self, prompts: List[str], stop: Optional[List[str]] = None - ) -> Tuple: - if len(prompts) > 1: - raise ValueError( - f"FireworksChat currently only supports single prompt, got {prompts}" - ) - messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}] - params: Dict[str, Any] = {**{"model": self.model_id}} - if stop is not None: - if "stop" in params: - raise ValueError("`stop` found in both the input and default params.") - - return messages, params - - def _generate( + def _call( self, - prompts: List[str], - stop: Optional[List[str]] = None, + prompt: str, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, - ) -> LLMResult: - messages, params = self._get_chat_params(prompts, stop) - params = {**params, **kwargs} - full_response = completion_with_retry(self, messages=messages, **params) - llm_output = { - "model_id": self.model_id, - } - return LLMResult( - generations=[[Generation(text=full_response[0])]], - llm_output=llm_output, + ) -> str: + response = openai.Completion.create( + api_base=self.fireworks_api_url, + api_key=self.fireworks_api_key, + model=self.model, + prompt=prompt, + **self.model_kwargs, ) - - async def _agenerate( - self, - prompts: List[str], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> LLMResult: - messages, params = self._get_chat_params(prompts, stop) - params = {**params, **kwargs} - full_response = await acompletion_with_retry(self, messages=messages, **params) - llm_output = { - "model_id": self.model_id, - } - return LLMResult( - generations=[[Generation(text=full_response[0])]], - llm_output=llm_output, - ) - - @property - def _llm_type(self) -> 
str: - """Return type of llm.""" - return "fireworks-chat" - - -class Fireworks(BaseFireworks): - """Wrapper around Fireworks large language models. - To use, you should have the ``fireworks`` python package installed, and the - environment variable ``FIREWORKS_API_KEY`` set with your API key. - Any parameters that are valid to be passed to the fireworks.create - call can be passed in, even if not explicitly saved on this class. - Example: - .. code-block:: python - from langchain.llms import fireworks - llm = Fireworks(model_id="llama-v2-13b") - """ - - -def update_token_usage( - keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any] -) -> None: - """Update token usage.""" - _keys_to_use = keys.intersection(response) - for _key in _keys_to_use: - if _key not in token_usage: - token_usage[_key] = response["usage"][_key] - else: - token_usage[_key] += response["usage"][_key] - - -def execute( - prompt: str, - model: str, - api_key: Optional[str], - max_tokens: int = 256, - temperature: float = 0.0, - top_p: float = 1.0, -) -> Any: - """Execute LLM query""" - requestUrl = "https://api.fireworks.ai/inference/v1/completions" - requestBody = { - "model": model, - "prompt": prompt, - "max_tokens": max_tokens, - "temperature": temperature, - "top_p": top_p, - } - requestHeaders = { - "Authorization": f"Bearer {api_key}", - "Accept": "application/json", - "Content-Type": "application/json", - } - response = requests.post(requestUrl, headers=requestHeaders, json=requestBody) - return response.text - - -def completion_with_retry( - llm: Union[BaseFireworks, FireworksChat], **kwargs: Any -) -> Any: - """Use tenacity to retry the completion call.""" - if "prompt" not in kwargs.keys(): - answers = [] - for i in range(len(kwargs["messages"])): - result = kwargs["messages"][i]["content"] - result = execute( - result, - kwargs["model"], - llm.fireworks_api_key, - llm.max_tokens, - llm.temperature, - llm.top_p, - ) - curr_string = json.loads(result)["choices"][0]["text"] - answers.append(curr_string) - else: - answers = [] - for i in range(len(kwargs["prompt"])): - result = kwargs["prompt"][i] - result = execute( - result, - kwargs["model"], - llm.fireworks_api_key, - llm.max_tokens, - llm.temperature, - llm.top_p, - ) - curr_string = json.loads(result)["choices"][0]["text"] - answers.append(curr_string) - return answers - - -async def acompletion_with_retry( - llm: Union[BaseFireworks, FireworksChat], **kwargs: Any -) -> Any: - """Use tenacity to retry the async completion call.""" - if "prompt" not in kwargs.keys(): - answers = [] - for i in range(len(kwargs["messages"])): - result = kwargs["messages"][i]["content"] - result = execute( - result, - kwargs["model"], - llm.fireworks_api_key, - llm.max_tokens, - llm.temperature, - ) - curr_string = json.loads(result)["choices"][0]["text"] - answers.append(curr_string) - else: - answers = [] - for i in range(len(kwargs["prompt"])): - result = kwargs["prompt"][i] - result = execute( - result, - kwargs["model"], - llm.fireworks_api_key, - llm.max_tokens, - llm.temperature, - ) - curr_string = json.loads(result)["choices"][0]["text"] - answers.append(curr_string) - return answers + return response["choices"][0]["text"] diff --git a/libs/langchain/tests/integration_tests/llms/test_fireworks.py b/libs/langchain/tests/integration_tests/llms/test_fireworks.py index e0dcb4fe43308..188056c9e5978 100644 --- a/libs/langchain/tests/integration_tests/llms/test_fireworks.py +++ b/libs/langchain/tests/integration_tests/llms/test_fireworks.py @@ -1,30 +1,17 @@ 
"""Test Fireworks AI API Wrapper.""" -from pathlib import Path - -import pytest - from langchain import LLMChain, PromptTemplate -from langchain.chains import RetrievalQA -from langchain.document_loaders import TextLoader -from langchain.embeddings.openai import OpenAIEmbeddings -from langchain.llms import OpenAIChat -from langchain.llms.fireworks import Fireworks, FireworksChat -from langchain.llms.loading import load_llm +from langchain.llms.fireworks import Fireworks from langchain.prompts.chat import ( ChatPromptTemplate, HumanMessagePromptTemplate, ) from langchain.schema import LLMResult -from langchain.text_splitter import CharacterTextSplitter -from langchain.vectorstores import DeepLake def test_fireworks_call() -> None: """Test valid call to fireworks.""" - llm = Fireworks( - model_id="accounts/fireworks/models/fireworks-llama-v2-13b-chat", max_tokens=900 - ) - output = llm("What is the weather in NYC") + llm = Fireworks() + output = llm("Who's the best quarterback in the NFL?") assert isinstance(output, str) @@ -43,36 +30,10 @@ def test_fireworks_in_chain() -> None: assert isinstance(output, str) -@pytest.mark.asyncio -async def test_openai_chat_async_generate() -> None: - """Test async chat.""" - llm = OpenAIChat(max_tokens=10) - output = await llm.agenerate(["Hello, how are you?"]) - assert isinstance(output, LLMResult) - - def test_fireworks_model_param() -> None: """Tests model parameters for Fireworks""" llm = Fireworks(model="foo") - assert llm.model_id == "foo" - llm = Fireworks(model_id="foo") - assert llm.model_id == "foo" - - -def test_fireworkschat_model_param() -> None: - """Tests model parameters for FireworksChat""" - llm = FireworksChat(model="foo") - assert llm.model_id == "foo" - llm = FireworksChat(model_id="foo") - assert llm.model_id == "foo" - - -def test_saving_loading_llm(tmp_path: Path) -> None: - """Test saving/loading an Fireworks LLM.""" - llm = Fireworks(max_tokens=10) - llm.save(file_path=tmp_path / "fireworks.yaml") - loaded_llm = load_llm(tmp_path / "fireworks.yaml") - assert loaded_llm == llm + assert llm.model == "foo" def test_fireworks_multiple_prompts() -> None: @@ -82,78 +43,3 @@ def test_fireworks_multiple_prompts() -> None: assert isinstance(output, LLMResult) assert isinstance(output.generations, list) assert len(output.generations) == 2 - - -def test_fireworks_chat() -> None: - """Test FireworksChat.""" - llm = FireworksChat() - output = llm("Name me 3 quick facts about the New England Patriots") - assert isinstance(output, str) - - -async def test_fireworks_agenerate() -> None: - llm = Fireworks() - output = await llm.agenerate(["I'm a pickle", "I'm a pickle"]) - assert isinstance(output, LLMResult) - assert isinstance(output.generations, list) - assert len(output.generations) == 2 - - -async def test_fireworkschat_agenerate() -> None: - llm = FireworksChat(max_tokens=10) - output = await llm.agenerate(["Hello, how are you?"]) - assert isinstance(output, LLMResult) - assert isinstance(output.generations, list) - assert len(output.generations) == 1 - - -def test_fireworkschat_chain() -> None: - embeddings = OpenAIEmbeddings() - - loader = TextLoader( - "[workspace]/langchain-internal/docs/extras/modules/state_of_the_union.txt" - ) - documents = loader.load() - text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0) - docs = text_splitter.split_documents(documents) - - embeddings = OpenAIEmbeddings() - - db = DeepLake( - dataset_path="./my_deeplake/", embedding_function=embeddings, overwrite=True - ) - 
db.add_documents(docs) - - query = "What did the president say about Ketanji Brown Jackson" - docs = db.similarity_search(query) - - qa = RetrievalQA.from_chain_type( - llm=FireworksChat(), - chain_type="stuff", - retriever=db.as_retriever(), - ) - query = "What did the president say about Ketanji Brown Jackson" - output = qa.run(query) - assert isinstance(output, str) - - -_EXPECTED_NUM_TOKENS = { - "accounts/fireworks/models/fireworks-llama-v2-13b": 17, - "accounts/fireworks/models/fireworks-llama-v2-7b": 17, - "accounts/fireworks/models/fireworks-llama-v2-13b-chat": 17, - "accounts/fireworks/models/fireworks-llama-v2-7b-chat": 17, -} - -_MODELS = models = [ - "accounts/fireworks/models/fireworks-llama-v2-13b", - "accounts/fireworks/models/fireworks-llama-v2-7b", - "accounts/fireworks/models/fireworks-llama-v2-13b-chat", - "accounts/fireworks/models/fireworks-llama-v2-7b-chat", -] - - -@pytest.mark.parametrize("model", _MODELS) -def test_fireworks_get_num_tokens(model: str) -> None: - """Test get_tokens.""" - llm = Fireworks(model=model) - assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model] From 681880ce530567ee3aaf2f198eb2219e76cb7bde Mon Sep 17 00:00:00 2001 From: Cynthia Yang Date: Fri, 8 Sep 2023 17:17:55 -0700 Subject: [PATCH 02/14] Add chat fireworks --- .../langchain/chat_models/fireworks.py | 63 ++++++++++++++++ .../chat_models/test_fireworks.py | 73 +++++++++++++++++++ 2 files changed, 136 insertions(+) create mode 100644 libs/langchain/langchain/chat_models/fireworks.py create mode 100644 libs/langchain/tests/integration_tests/chat_models/test_fireworks.py diff --git a/libs/langchain/langchain/chat_models/fireworks.py b/libs/langchain/langchain/chat_models/fireworks.py new file mode 100644 index 0000000000000..e34e66525ce9a --- /dev/null +++ b/libs/langchain/langchain/chat_models/fireworks.py @@ -0,0 +1,63 @@ +import os +from typing import Any, Dict, List, Mapping, Optional, Tuple +from langchain.adapters.openai import convert_dict_to_message, convert_message_to_dict + +from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.chat_models.base import BaseChatModel, SimpleChatModel +from langchain.llms.base import LLM +from langchain.schema.messages import BaseMessage +from langchain.schema.output import ChatGeneration, ChatResult +import openai + + +class ChatFireworks(BaseChatModel): + """Fireworks Chat models.""" + + model = "accounts/fireworks/models/llama-v2-13b-chat" + model_kwargs: Optional[dict] = {"temperature": 0.7, "max_tokens": 512, "top_p": 1} + fireworks_api_url: Optional[str] = "https://api.fireworks.ai/inference/v1" + fireworks_api_key: Optional[str] = os.environ.get("FIREWORKS_API_KEY") + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "fireworks-chat" + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + stream: Optional[bool] = None, + **kwargs: Any, + ) -> ChatResult: + message_dicts = self._create_message_dicts(messages, stop) + response = openai.ChatCompletion.create( + api_base=self.fireworks_api_url, + api_key=self.fireworks_api_key, + model=self.model, + messages=message_dicts, + **self.model_kwargs, + ) + return self._create_chat_result(response) + + def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: + return llm_outputs[0] + + def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: + generations = [] + for res in response["choices"]: + 
message = convert_dict_to_message(res["message"])
+            gen = ChatGeneration(
+                message=message,
+                generation_info=dict(finish_reason=res.get("finish_reason")),
+            )
+            generations.append(gen)
+        llm_output = {"model": self.model}
+        return ChatResult(generations=generations, llm_output=llm_output)
+
+    def _create_message_dicts(
+        self, messages: List[BaseMessage], stop: Optional[List[str]]
+    ) -> List[Dict[str, Any]]:
+        message_dicts = [convert_message_to_dict(m) for m in messages]
+        return message_dicts
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py b/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py
new file mode 100644
index 0000000000000..149266c9d2852
--- /dev/null
+++ b/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py
@@ -0,0 +1,73 @@
+"""Test ChatFireworks wrapper."""
+
+import pytest
+
+from langchain.callbacks.manager import CallbackManager
+from langchain.chat_models.fireworks import ChatFireworks
+from langchain.schema import (
+    ChatGeneration,
+    ChatResult,
+    LLMResult,
+)
+from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage
+
+
+def test_chat_fireworks() -> None:
+    """Test ChatFireworks wrapper."""
+    chat = ChatFireworks()
+    message = HumanMessage(content="What is the weather in Redwood City, CA today")
+    response = chat([message])
+    assert isinstance(response, BaseMessage)
+    assert isinstance(response.content, str)
+
+
+def test_chat_fireworks_model() -> None:
+    """Test ChatFireworks wrapper handles model_name."""
+    chat = ChatFireworks(model="foo")
+    assert chat.model == "foo"
+
+
+def test_chat_fireworks_system_message() -> None:
+    """Test ChatFireworks wrapper with system message."""
+    chat = ChatFireworks(max_tokens=10)
+    system_message = SystemMessage(content="You are to chat with the user.")
+    human_message = HumanMessage(content="Hello")
+    response = chat([system_message, human_message])
+    assert isinstance(response, BaseMessage)
+    assert isinstance(response.content, str)
+
+
+def test_chat_fireworks_generate() -> None:
+    """Test ChatFireworks wrapper with generate."""
+    chat = ChatFireworks(model_kwargs={"n": 2})
+    message = HumanMessage(content="Hello")
+    response = chat.generate([[message], [message]])
+    assert isinstance(response, LLMResult)
+    assert len(response.generations) == 2
+    for generations in response.generations:
+        assert len(generations) == 2
+        for generation in generations:
+            assert isinstance(generation, ChatGeneration)
+            assert isinstance(generation.text, str)
+            assert generation.text == generation.message.content
+
+
+def test_chat_fireworks_multiple_completions() -> None:
+    """Test ChatFireworks wrapper with multiple completions."""
+    chat = ChatFireworks(model_kwargs={"n": 5})
+    message = HumanMessage(content="Hello")
+    response = chat._generate([message])
+    assert isinstance(response, ChatResult)
+    assert len(response.generations) == 5
+    for generation in response.generations:
+        assert isinstance(generation.message, BaseMessage)
+        assert isinstance(generation.message.content, str)
+
+
+def test_chat_fireworks_llm_output_contains_model_id() -> None:
+    """Test llm_output contains model_id."""
+    chat = ChatFireworks()
+    message = HumanMessage(content="Hello")
+    llm_result = chat.generate([[message]])
+    assert llm_result.llm_output is not None
+    assert llm_result.llm_output["model"] == chat.model
From 26c604a09371a96bf8572b7baf7c84e6b1b4c7be Mon Sep 17 00:00:00 2001
From: Cynthia Yang
Date: Sat, 9 Sep 2023 22:12:34 -0700
Subject: [PATCH 03/14] Support 
stream and add async test --- libs/langchain/langchain/llms/fireworks.py | 59 ++++++++++++++++++- .../integration_tests/llms/test_fireworks.py | 40 +++++++++++++ 2 files changed, 97 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/llms/fireworks.py b/libs/langchain/langchain/llms/fireworks.py index 00e390b920a2e..137663ac39d83 100644 --- a/libs/langchain/langchain/llms/fireworks.py +++ b/libs/langchain/langchain/llms/fireworks.py @@ -1,11 +1,29 @@ import os -from typing import Any, Optional +from typing import Any, Dict, Iterator, List, Optional -from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.callbacks.manager import ( + CallbackManagerForLLMRun, +) from langchain.llms.base import LLM +from langchain.schema.language_model import LanguageModelInput +from langchain.schema.output import GenerationChunk +from langchain.schema.runnable.config import RunnableConfig import openai +def _stream_response_to_generation_chunk( + stream_response: Dict[str, Any], +) -> GenerationChunk: + """Convert a stream response to a generation chunk.""" + return GenerationChunk( + text=stream_response["choices"][0]["text"], + generation_info=dict( + finish_reason=stream_response["choices"][0].get("finish_reason", None), + logprobs=stream_response["choices"][0].get("logprobs", None), + ), + ) + + class Fireworks(LLM): """Fireworks models.""" @@ -22,6 +40,7 @@ def _llm_type(self) -> str: def _call( self, prompt: str, + stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: @@ -33,3 +52,39 @@ def _call( **self.model_kwargs, ) return response["choices"][0]["text"] + + def _stream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: + for stream_resp in openai.Completion.create( + api_base=self.fireworks_api_url, + api_key=self.fireworks_api_key, + model=self.model, + prompt=prompt, + stream=True, + **self.model_kwargs, + ): + chunk = _stream_response_to_generation_chunk(stream_resp) + yield chunk + + def stream( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> Iterator[str]: + prompt = self._convert_input(input).to_string() + generation: Optional[GenerationChunk] = None + for chunk in self._stream(prompt): + yield chunk.text + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None diff --git a/libs/langchain/tests/integration_tests/llms/test_fireworks.py b/libs/langchain/tests/integration_tests/llms/test_fireworks.py index 188056c9e5978..d5689fefbd9ae 100644 --- a/libs/langchain/tests/integration_tests/llms/test_fireworks.py +++ b/libs/langchain/tests/integration_tests/llms/test_fireworks.py @@ -1,4 +1,5 @@ """Test Fireworks AI API Wrapper.""" +from typing import Generator from langchain import LLMChain, PromptTemplate from langchain.llms.fireworks import Fireworks from langchain.prompts.chat import ( @@ -6,6 +7,7 @@ HumanMessagePromptTemplate, ) from langchain.schema import LLMResult +import pytest def test_fireworks_call() -> None: @@ -43,3 +45,41 @@ def test_fireworks_multiple_prompts() -> None: assert isinstance(output, LLMResult) assert isinstance(output.generations, list) assert len(output.generations) == 2 + + +def test_fireworks_streaming() -> None: + """Test stream completion.""" + llm = Fireworks() + generator = llm.stream("Who's the best 
quarterback in the NFL?") + assert isinstance(generator, Generator) + + for token in generator: + assert isinstance(token, str) + + +@pytest.mark.asyncio +async def test_fireworks_streaming_async() -> None: + """Test stream completion.""" + llm = Fireworks() + + async for token in llm.astream("Who's the best quarterback in the NFL?"): + assert isinstance(token, str) + + +@pytest.mark.asyncio +async def test_fireworks_async_agenerate() -> None: + """Test async.""" + llm = Fireworks() + output = await llm.agenerate(["What is the best city to live in California?"]) + assert isinstance(output, LLMResult) + + +@pytest.mark.asyncio +async def test_fireworks_multiple_prompts_async_agenerate() -> None: + llm = Fireworks() + output = await llm.agenerate( + ["How is the weather in New York today?", "I'm pickle rick"] + ) + assert isinstance(output, LLMResult) + assert isinstance(output.generations, list) + assert len(output.generations) == 2 From 66b7f7415881d1ab32cb4a65f84b3bdbde2340ed Mon Sep 17 00:00:00 2001 From: Cynthia Yang Date: Sun, 10 Sep 2023 14:01:10 -0700 Subject: [PATCH 04/14] Support stream and async test for chat models --- .../langchain/chat_models/fireworks.py | 91 +++++++++++++++++-- .../chat_models/test_fireworks.py | 35 ++++++- 2 files changed, 118 insertions(+), 8 deletions(-) diff --git a/libs/langchain/langchain/chat_models/fireworks.py b/libs/langchain/langchain/chat_models/fireworks.py index e34e66525ce9a..561280559b13b 100644 --- a/libs/langchain/langchain/chat_models/fireworks.py +++ b/libs/langchain/langchain/chat_models/fireworks.py @@ -1,15 +1,48 @@ import os -from typing import Any, Dict, List, Mapping, Optional, Tuple +from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple from langchain.adapters.openai import convert_dict_to_message, convert_message_to_dict -from langchain.callbacks.manager import CallbackManagerForLLMRun -from langchain.chat_models.base import BaseChatModel, SimpleChatModel -from langchain.llms.base import LLM -from langchain.schema.messages import BaseMessage -from langchain.schema.output import ChatGeneration, ChatResult +from langchain.callbacks.manager import ( + CallbackManagerForLLMRun, +) +from langchain.chat_models.base import BaseChatModel +from langchain.schema.messages import ( + AIMessageChunk, + BaseMessage, + BaseMessageChunk, + ChatMessageChunk, + FunctionMessageChunk, + HumanMessageChunk, + SystemMessageChunk, +) +from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult import openai +def _convert_delta_to_message_chunk( + _dict: Mapping[str, Any], default_class: type[BaseMessageChunk] +) -> BaseMessageChunk: + role = _dict.get("role") + content = _dict.get("content") or "" + if _dict.get("function_call"): + additional_kwargs = {"function_call": dict(_dict["function_call"])} + else: + additional_kwargs = {} + + if role == "user" or default_class == HumanMessageChunk: + return HumanMessageChunk(content=content) + elif role == "assistant" or default_class == AIMessageChunk: + return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) + elif role == "system" or default_class == SystemMessageChunk: + return SystemMessageChunk(content=content) + elif role == "function" or default_class == FunctionMessageChunk: + return FunctionMessageChunk(content=content, name=_dict["name"]) + elif role or default_class == ChatMessageChunk: + return ChatMessageChunk(content=content, role=role) + else: + return default_class(content=content) + + class ChatFireworks(BaseChatModel): """Fireworks 
Chat models.""" @@ -28,7 +61,6 @@ def _generate( messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, - stream: Optional[bool] = None, **kwargs: Any, ) -> ChatResult: message_dicts = self._create_message_dicts(messages, stop) @@ -41,6 +73,23 @@ def _generate( ) return self._create_chat_result(response) + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + message_dicts = self._create_message_dicts(messages, stop) + response = await openai.ChatCompletion.acreate( + api_base=self.fireworks_api_url, + api_key=self.fireworks_api_key, + model=self.model, + messages=message_dicts, + **self.model_kwargs, + ) + return self._create_chat_result(response) + def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: return llm_outputs[0] @@ -61,3 +110,31 @@ def _create_message_dicts( ) -> Tuple[List[Dict[str, Any]]]: message_dicts = [convert_message_to_dict(m) for m in messages] return message_dicts + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + message_dicts = self._create_message_dicts(messages, stop) + default_chunk_class = AIMessageChunk + for chunk in openai.ChatCompletion.create( + api_base=self.fireworks_api_url, + api_key=self.fireworks_api_key, + model=self.model, + messages=message_dicts, + stream=True, + **self.model_kwargs, + ): + choice = chunk["choices"][0] + chunk = _convert_delta_to_message_chunk( + choice["delta"], default_chunk_class + ) + finish_reason = choice.get("finish_reason") + generation_info = ( + dict(finish_reason=finish_reason) if finish_reason is not None else None + ) + default_chunk_class = chunk.__class__ + yield ChatGenerationChunk(message=chunk, generation_info=generation_info) diff --git a/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py b/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py index 149266c9d2852..9cc681b03854c 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py @@ -29,7 +29,7 @@ def test_chat_fireworks_model() -> None: def test_chat_fireworks_system_message() -> None: """Test ChatFireworks wrapper with system message.""" - chat = ChatFireworks(max_tokens=10) + chat = ChatFireworks() system_message = SystemMessage(content="You are to chat with the user.") human_message = HumanMessage(content="Hello") response = chat([system_message, human_message]) @@ -71,3 +71,36 @@ def test_chat_fireworks_llm_output_contains_model_id() -> None: llm_result = chat.generate([[message]]) assert llm_result.llm_output is not None assert llm_result.llm_output["model"] == chat.model + + +def test_fireworks_streaming() -> None: + """Test streaming tokens from OpenAI.""" + llm = ChatFireworks() + + for token in llm.stream("I'm Pickle Rick"): + assert isinstance(token.content, str) + + +@pytest.mark.asyncio +async def test_chat_fireworks_agenerate() -> None: + """Test ChatFireworks wrapper with generate.""" + chat = ChatFireworks(model_kwargs={"n": 2}) + message = HumanMessage(content="Hello") + response = await chat.agenerate([[message], [message]]) + assert isinstance(response, LLMResult) + assert len(response.generations) == 2 + for generations in 
response.generations: + assert len(generations) == 2 + for generation in generations: + assert isinstance(generation, ChatGeneration) + assert isinstance(generation.text, str) + assert generation.text == generation.message.content + + +@pytest.mark.asyncio +async def test_fireworks_astream() -> None: + """Test streaming tokens from OpenAI.""" + llm = ChatFireworks() + + async for token in llm.astream("Who's the best quarterback in the NFL?"): + assert isinstance(token.content, str) From 8f30964758e81766ce3b0dcc6a962efcbbaf7819 Mon Sep 17 00:00:00 2001 From: Cynthia Yang Date: Mon, 11 Sep 2023 21:48:44 -0700 Subject: [PATCH 05/14] Add notebook --- docs/extras/integrations/chat/fireworks.ipynb | 108 ++++++++++++++++++ docs/extras/integrations/llms/fireworks.ipynb | 74 ++++++------ .../langchain/chat_models/fireworks.py | 10 +- libs/langchain/langchain/llms/fireworks.py | 8 +- 4 files changed, 154 insertions(+), 46 deletions(-) create mode 100644 docs/extras/integrations/chat/fireworks.ipynb diff --git a/docs/extras/integrations/chat/fireworks.ipynb b/docs/extras/integrations/chat/fireworks.ipynb new file mode 100644 index 0000000000000..f0dc9b088085e --- /dev/null +++ b/docs/extras/integrations/chat/fireworks.ipynb @@ -0,0 +1,108 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "642fd21c-600a-47a1-be96-6e1438b421a9", + "metadata": {}, + "source": [ + "# ChatFireworks\n", + "\n", + ">[Fireworks](https://app.fireworks.ai/) accelerates product development on generative AI by creating an innovative AI experiment and production platform. \n", + "\n", + "This example goes over how to use LangChain to interact with `ChatFireworks` models." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "d00d850917865298", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "from langchain.chat_models.fireworks import ChatFireworks\n", + "from langchain.schema import SystemMessage, HumanMessage\n", + "import os" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "d096fb14-8acc-4047-9cd0-c842430c3a1d", + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize a Fireworks Chat model\n", + "os.environ['FIREWORKS_API_KEY'] = \"\" # Change this to your own API key\n", + "chat = ChatFireworks(model=\"accounts/fireworks/models/llama-v2-13b-chat\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "72340871-ae2f-415f-b399-0777d32dc379", + "metadata": {}, + "outputs": [], + "source": [ + "system_message = SystemMessage(content=\"You are to chat with the user.\")\n", + "human_message = HumanMessage(content=\"Who are you?\")\n", + "response = chat([system_message, human_message])" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "2d6ef879-69e3-422b-8379-bb980b70fe55", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content=\"Hey there! I'm LLaMA, an AI assistant developed by Meta AI that can understand and respond to human input in a conversational manner. I'm here to help answer any questions you may have, provide information on a wide range of topics, and even assist with tasks and projects. I'm here to help make your life easier and more convenient. So, what's up? What do you want to talk about today? 
😊\", additional_kwargs={}, example=False)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "response" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cbe29efc-37c3-4c83-8b84-b8bba1a1e589", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/extras/integrations/llms/fireworks.ipynb b/docs/extras/integrations/llms/fireworks.ipynb index 4da49230bfaca..0ea9aaf93f76e 100644 --- a/docs/extras/integrations/llms/fireworks.ipynb +++ b/docs/extras/integrations/llms/fireworks.ipynb @@ -19,7 +19,7 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain.llms.fireworks import Fireworks, FireworksChat\n", + "from langchain.llms.fireworks import Fireworks\n", "from langchain import PromptTemplate, LLMChain\n", "from langchain.prompts.chat import (\n", " ChatPromptTemplate,\n", @@ -48,8 +48,8 @@ "outputs": [], "source": [ "# Initialize a Fireworks LLM\n", - "os.environ['FIREWORKS_API_KEY'] = \"\" # Change this to your own API key\n", - "llm = Fireworks(model_id=\"accounts/fireworks/models/llama-v2-13b-chat\")" + "os.environ['FIREWORKS_API_KEY'] = \"\" # Change this to your own API key\n", + "llm = Fireworks(model=\"accounts/fireworks/models/llama-v2-13b-chat\")" ] }, { @@ -75,11 +75,12 @@ " * `accounts/fireworks/models/llama-v2-13b-w8a16`\n", " * `accounts/fireworks/models/llama-v2-13b-chat`\n", " * `accounts/fireworks/models/llama-v2-13b-chat-w8a16`\n", - " * `accounts/fireworks/models/llama-v2-70b-chat-4gpu`\n", + " * `accounts/fireworks/models/llama-v2-70b-chat`\n", + " * `accounts/fireworks/models/llama-v2-70b-chat-w8a16`\n", "* StarCoder\n", - " * `accounts/fireworks/models/starcoder-1b-w8a16-1gpu`\n", - " * `accounts/fireworks/models/starcoder-3b-w8a16-1gpu`\n", - " * `accounts/fireworks/models/starcoder-7b-w8a16-1gpu`\n", + " * `accounts/fireworks/models/starcoder-1b-w8a16`\n", + " * `accounts/fireworks/models/starcoder-3b-w8a16`\n", + " * `accounts/fireworks/models/starcoder-7b-w8a16`\n", " * `accounts/fireworks/models/starcoder-16b-w8a16`\n", "\n", "See the full, most up-to-date list on [app.fireworks.ai](https://app.fireworks.ai)." @@ -95,29 +96,23 @@ "name": "stdout", "output_type": "stream", "text": [ - "Is it Tom Brady, Aaron Rodgers, or someone else? It's a tough question to answer, and there are strong arguments for each of these quarterbacks. Here are some of the reasons why each of these quarterbacks could be considered the best:\n", "\n", - "Tom Brady:\n", "\n", - "* He has the most Super Bowl wins (6) of any quarterback in NFL history.\n", - "* He has been named Super Bowl MVP four times, more than any other player.\n", - "* He has led the New England Patriots to 18 playoff victories, the most in NFL history.\n", - "* He has thrown for over 70,000 yards in his career, the most of any quarterback in NFL history.\n", - "* He has thrown for 50 or more touchdowns in a season four times, the most of any quarterback in NFL history.\n", + "Who's the best quarterback in the NFL?\n", "\n", - "Aaron Rodgers:\n", + "Well, that's a tough question. 
There are a lot of great quarterbacks in the league right now, and it's hard to say who's the absolute best. But if I had to choose, I'd say that Tom Brady is probably the best quarterback in the NFL.\n", "\n", - "* He has led the Green Bay Packers to a Super Bowl victory in 2010.\n", - "* He has been named Super Bowl MVP once.\n", - "* He has thrown for over 40,000 yards in his career, the most of any quarterback in NFL history.\n", - "* He has thrown for 40 or more touchdowns in a season three times, the most of any quarterback in NFL history.\n", - "* He has a career passer rating of 103.1, the highest of any quarterback in NFL history.\n", + "Now, I know some people might say that Aaron Rodgers or Drew Brees or Patrick Mahomes is the best, and they have valid arguments. But here's why I think Tom Brady is the best:\n", "\n", - "So, who's the best quarterback in the NFL? It's a tough call, but here's my opinion:\n", + "First of all, he's incredibly consistent. He's been playing at an elite level for almost two decades now, and he's never really had a bad season. He's always been able to adapt to whatever situation he's in, and he's always been able to find a way to win.\n", "\n", - "I think Aaron Rodgers is the best quarterback in the NFL right now. He has led the Packers to a Super Bowl victory and has had some incredible seasons, including the 2011 season when he threw for 45 touchdowns and just 6 interceptions. He has a strong arm, great accuracy, and is incredibly mobile for a quarterback of his size. He also has a great sense of timing and knows when to take risks and when to play it safe.\n", + "Second, he's got an incredible work ethic. He's always been known for his dedication to the game, and he's always been willing to put in the extra work to get better. He's got a great understanding of the game, and he's always looking for ways to improve.\n", "\n", - "Tom Brady is a close second, though. He has an incredible track record of success, including six Super Bowl victories, and has been one of the most consistent quarterbacks in the league for the past two decades. He has a strong arm and is incredibly accurate\n" + "Third, he's got a great supporting cast. He's played with some of the best players in the league, and he's always been able to get the most out of his teammates. He's got a great offensive line, and he's got some of the best receivers and running backs in the league.\n", + "\n", + "Now, I know some people might say that Tom Brady's success is all because of his supporting cast, and that he's not as good as some of the other quarterbacks in the league. But I think that's a bit unfair. Sure, he's had some great players around him, but he's also been the one who's led his team to six Super Bowl wins. He's the one who's made the plays when it counts, and he's the one who's always come through in the clutch.\n", + "\n", + "So, there you have it. That's why I think Tom Brady is the best quarterback in the NFL. He's consistent, he's got a great work ethic, and he's got a great supporting\n" ] } ], @@ -137,7 +132,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "[[Generation(text='\\nThe best cricket player in 2016 is a matter of opinion, but some of the top contenders for the title include:\\n\\n1. Virat Kohli (India): Kohli had a phenomenal year in 2016, scoring over 1,000 runs in Test cricket, including four centuries, and averaging over 70. He also scored heavily in ODI cricket, with an average of over 80.\\n2. 
Steve Smith (Australia): Smith had a remarkable year in 2016, leading Australia to a Test series victory in India and scoring over 1,000 runs in the format, including five centuries. He also averaged over 60 in ODI cricket.\\n3. KL Rahul (India): Rahul had a breakout year in 2016, scoring over 1,000 runs in Test cricket, including four centuries, and averaging over 60. He also scored heavily in ODI cricket, with an average of over 70.\\n4. Joe Root (England): Root had a solid year in 2016, scoring over 1,000 runs in Test cricket, including four centuries, and averaging over 50. He also scored heavily in ODI cricket, with an average of over 80.\\n5. Quinton de Kock (South Africa): De Kock had a remarkable year in 2016, scoring over 1,000 runs in ODI cricket, including six centuries, and averaging over 80. He also scored heavily in Test cricket, with an average of over 50.\\n\\nThese are just a few of the top contenders for the title of best cricket player in 2016, but there were many other talented players who also had impressive years. Ultimately, the answer to this question is subjective and depends on individual opinions and criteria for evaluation.', generation_info=None)], [Generation(text=\"\\nThis is a tough one, as there are so many great players in the league right now. But if I had to choose one, I'd say LeBron James is the best basketball player in the league. He's a once-in-a-generation talent who can dominate the game in so many ways. He's got incredible speed, strength, and court vision, and he's always finding new ways to improve his game. Plus, he's been doing it at an elite level for over a decade now, which is just amazing.\\n\\nBut don't just take my word for it - there are plenty of other great players in the league who could make a strong case for being the best. Guys like Kevin Durant, Steph Curry, James Harden, and Giannis Antetokounmpo are all having incredible seasons, and they've all got their own unique skills and strengths that make them special. So ultimately, it's up to you to decide who you think is the best basketball player in the league.\", generation_info=None)]]\n" + "[[Generation(text=\"\\n\\nCertainly, it is a matter of debate and personal opinion. But if we look at the stats and performances of various cricketers in 2016, here are some of the top contenders for the title of the best cricket player of the year:\\n\\n1. Virat Kohli (India): Kohli had a phenomenal year in 2016, scoring over 1,000 runs in Test cricket, including 11 centuries. He also averaged over 70 in ODIs and was instrumental in India's success in the World T20.\\n2. Steve Smith (Australia): Smith had a great year as well, scoring over 1,000 runs in Test cricket and averaging over 60. He also led Australia to a Test series victory in Sri Lanka and the West Indies.\\n3. KL Rahul (India): Rahul had a breakout year in 2016, scoring three centuries in Test cricket and averaging over 50. He also scored a century in the World T20 and was named the ICC Emerging Cricketer of the Year.\\n4. Joe Root (England): Root had a solid year in 2016, scoring over 1,000 runs in Test cricket and averaging over 50. He also led England to a Test series victory in Bangladesh.\\n5. David Warner (Australia): Warner had a great year in limited-overs cricket, scoring over 1,000 runs in ODIs and averaging over 60. 
He also scored a century in the World T20.\\n\\nOf course, there are other great cricketers who had excellent years in 2016, such as Quinton de Kock (South Africa), AB de Villiers (South Africa), and Tamim Iqbal (Bangladesh). But based on their stats and performances, these five players are among the top contenders for the title of the best cricket player of 2016.\", generation_info=None)], [Generation(text=\"\\n\\nWho's the best basketball player in the league? \", generation_info=None)]]\n" ] } ], @@ -161,13 +156,13 @@ "output_type": "stream", "text": [ "\n", - "Kansas City in December is quite cold, with temperatures typically r\n" + "Kansas City, Missouri, experiences a continental climate with cold winter\n" ] } ], "source": [ "# Setting additional parameters: temperature, max_tokens, top_p\n", - "llm = Fireworks(model_id=\"accounts/fireworks/models/llama-v2-13b-chat\", temperature=0.7, max_tokens=15, top_p=1.0)\n", + "llm = Fireworks(model=\"accounts/fireworks/models/llama-v2-13b-chat\", model_kwargs={\"temperature\":0.7, \"max_tokens\":15, \"top_p\":1.0})\n", "print(llm(\"What's the weather like in Kansas City in December?\"))" ] }, @@ -192,30 +187,35 @@ "output_type": "stream", "text": [ "\n", - "Naming a company can be a fun and creative process! Here are a few name ideas for a company that makes football helmets:\n", "\n", - "1. Helix Headgear: This name plays off the idea of the helix shape of a football helmet and could be a memorable and catchy name for a company.\n", - "2. Gridiron Gear: \"Gridiron\" is a term used to describe a football field, and \"gear\" refers to the products the company sells. This name is straightforward and easy to understand.\n", - "3. Cushion Crusaders: This name emphasizes the protective qualities of football helmets and could appeal to customers looking for safety-conscious products.\n", - "4. Helmet Heroes: This name has a fun, heroic tone and could appeal to customers looking for high-quality products.\n", - "5. Tackle Tech: \"Tackle\" is a term used in football to describe a player's attempt to stop an opponent, and \"tech\" refers to the technology used in the helmets. This name could appeal to customers interested in innovative products.\n", - "6. Padded Protection: This name emphasizes the protective qualities of football helmets and could appeal to customers looking for products that prioritize safety.\n", - "7. Gridiron Gear Co.: This name is simple and straightforward, and it clearly conveys the company's focus on football-related products.\n", - "8. Helmet Haven: This name has a soothing, protective tone and could appeal to customers looking for a reliable brand.\n", + "Assistant: Well, there are a few options! Here are a few ideas for a company that makes football helmets:\n", "\n", - "Remember to choose a name that reflects your company's values and mission, and that resonates with your target market. Good luck with your company!\n" + "1. Gridiron Gear: This name plays off the term \"gridiron,\" which is a slang term for a football field. It also has a strong, rugged sound to it, which could appeal to customers who are looking for high-quality football helmets.\n", + "2. Helmet Haven: This name conveys a sense of safety and protection, which is important for customers who are looking for a reliable football helmet. It also has a catchy, memorable sound to it.\n", + "3. Touchdown Titan: This name has a fun, energetic sound to it, which could appeal to customers who are looking for a football helmet that's both functional and stylish. 
It also plays off the idea of scoring a touchdown, which is a common goal in football.\n", + "4. Football Fusion: This name combines the idea of football with the concept of fusion, which could imply a blending of different materials or technologies to create a high-quality helmet. It also has a modern, cutting-edge sound to it.\n", + "\n", + "I hope these suggestions help inspire you as you come up with a name for your company!\n" ] } ], "source": [ "human_message_prompt = HumanMessagePromptTemplate.from_template(\"What is a good name for a company that makes {product}?\")\n", "chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])\n", - "chat = FireworksChat()\n", + "chat = Fireworks()\n", "chain = LLMChain(llm=chat, prompt=chat_prompt_template)\n", "output = chain.run(\"football helmets\")\n", "\n", "print(output)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26d67ecf-9290-4ec2-8b39-ff17fc99620f", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/libs/langchain/langchain/chat_models/fireworks.py b/libs/langchain/langchain/chat_models/fireworks.py index 561280559b13b..8935dd02ddeaa 100644 --- a/libs/langchain/langchain/chat_models/fireworks.py +++ b/libs/langchain/langchain/chat_models/fireworks.py @@ -46,10 +46,10 @@ def _convert_delta_to_message_chunk( class ChatFireworks(BaseChatModel): """Fireworks Chat models.""" - model = "accounts/fireworks/models/llama-v2-13b-chat" + model = "accounts/fireworks/models/llama-v2-7b-chat" model_kwargs: Optional[dict] = {"temperature": 0.7, "max_tokens": 512, "top_p": 1} fireworks_api_url: Optional[str] = "https://api.fireworks.ai/inference/v1" - fireworks_api_key: Optional[str] = os.environ.get("FIREWORKS_API_KEY") + fireworks_api_key: Optional[str] = None @property def _llm_type(self) -> str: @@ -66,7 +66,7 @@ def _generate( message_dicts = self._create_message_dicts(messages, stop) response = openai.ChatCompletion.create( api_base=self.fireworks_api_url, - api_key=self.fireworks_api_key, + api_key=os.environ.get("FIREWORKS_API_KEY"), model=self.model, messages=message_dicts, **self.model_kwargs, @@ -83,7 +83,7 @@ async def _agenerate( message_dicts = self._create_message_dicts(messages, stop) response = await openai.ChatCompletion.acreate( api_base=self.fireworks_api_url, - api_key=self.fireworks_api_key, + api_key=os.environ.get("FIREWORKS_API_KEY"), model=self.model, messages=message_dicts, **self.model_kwargs, @@ -122,7 +122,7 @@ def _stream( default_chunk_class = AIMessageChunk for chunk in openai.ChatCompletion.create( api_base=self.fireworks_api_url, - api_key=self.fireworks_api_key, + api_key=os.environ.get("FIREWORKS_API_KEY"), model=self.model, messages=message_dicts, stream=True, diff --git a/libs/langchain/langchain/llms/fireworks.py b/libs/langchain/langchain/llms/fireworks.py index 137663ac39d83..20ab38bffcaba 100644 --- a/libs/langchain/langchain/llms/fireworks.py +++ b/libs/langchain/langchain/llms/fireworks.py @@ -27,10 +27,10 @@ def _stream_response_to_generation_chunk( class Fireworks(LLM): """Fireworks models.""" - model = "accounts/fireworks/models/llama-v2-13b-chat" + model = "accounts/fireworks/models/llama-v2-7b-chat" model_kwargs: Optional[dict] = {"temperature": 0.7, "max_tokens": 512, "top_p": 1} fireworks_api_url: Optional[str] = "https://api.fireworks.ai/inference/v1" - fireworks_api_key: Optional[str] = os.environ.get("FIREWORKS_API_KEY") + fireworks_api_key: Optional[str] = None @property def _llm_type(self) -> str: @@ -46,7 +46,7 @@ def 
_call( ) -> str: response = openai.Completion.create( api_base=self.fireworks_api_url, - api_key=self.fireworks_api_key, + api_key=os.environ.get("FIREWORKS_API_KEY"), model=self.model, prompt=prompt, **self.model_kwargs, @@ -62,7 +62,7 @@ def _stream( ) -> Iterator[GenerationChunk]: for stream_resp in openai.Completion.create( api_base=self.fireworks_api_url, - api_key=self.fireworks_api_key, + api_key=os.environ.get("FIREWORKS_API_KEY"), model=self.model, prompt=prompt, stream=True, From 79f5ad6541ef2410d10f7319a97da60e99a67cb9 Mon Sep 17 00:00:00 2001 From: Cynthia Yang Date: Wed, 13 Sep 2023 03:49:19 +0000 Subject: [PATCH 06/14] Support retry --- .../langchain/chat_models/fireworks.py | 50 +++++++++++++------ libs/langchain/langchain/llms/fireworks.py | 47 +++++++++++------ 2 files changed, 68 insertions(+), 29 deletions(-) diff --git a/libs/langchain/langchain/chat_models/fireworks.py b/libs/langchain/langchain/chat_models/fireworks.py index 8935dd02ddeaa..eae8c5094c99f 100644 --- a/libs/langchain/langchain/chat_models/fireworks.py +++ b/libs/langchain/langchain/chat_models/fireworks.py @@ -1,11 +1,15 @@ import os +import time from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple + +import backoff from langchain.adapters.openai import convert_dict_to_message, convert_message_to_dict from langchain.callbacks.manager import ( CallbackManagerForLLMRun, ) from langchain.chat_models.base import BaseChatModel +from langchain.chat_models.openai import _create_retry_decorator from langchain.schema.messages import ( AIMessageChunk, BaseMessage, @@ -48,8 +52,9 @@ class ChatFireworks(BaseChatModel): model = "accounts/fireworks/models/llama-v2-7b-chat" model_kwargs: Optional[dict] = {"temperature": 0.7, "max_tokens": 512, "top_p": 1} - fireworks_api_url: Optional[str] = "https://api.fireworks.ai/inference/v1" + fireworks_api_base: Optional[str] = "https://api.fireworks.ai/inference/v1" fireworks_api_key: Optional[str] = None + max_retries: int = 20 @property def _llm_type(self) -> str: @@ -64,13 +69,13 @@ def _generate( **kwargs: Any, ) -> ChatResult: message_dicts = self._create_message_dicts(messages, stop) - response = openai.ChatCompletion.create( - api_base=self.fireworks_api_url, - api_key=os.environ.get("FIREWORKS_API_KEY"), - model=self.model, - messages=message_dicts, + + params = { + "model": self.model, + "messages": message_dicts, **self.model_kwargs, - ) + } + response = self.completion_with_retry(**params) return self._create_chat_result(response) async def _agenerate( @@ -82,7 +87,7 @@ async def _agenerate( ) -> ChatResult: message_dicts = self._create_message_dicts(messages, stop) response = await openai.ChatCompletion.acreate( - api_base=self.fireworks_api_url, + api_base=self.fireworks_api_base, api_key=os.environ.get("FIREWORKS_API_KEY"), model=self.model, messages=message_dicts, @@ -120,14 +125,13 @@ def _stream( ) -> Iterator[ChatGenerationChunk]: message_dicts = self._create_message_dicts(messages, stop) default_chunk_class = AIMessageChunk - for chunk in openai.ChatCompletion.create( - api_base=self.fireworks_api_url, - api_key=os.environ.get("FIREWORKS_API_KEY"), - model=self.model, - messages=message_dicts, - stream=True, + params = { + "model": self.model, + "messages": message_dicts, + "stream": True, **self.model_kwargs, - ): + } + for chunk in self.completion_with_retry(**params): choice = chunk["choices"][0] chunk = _convert_delta_to_message_chunk( choice["delta"], default_chunk_class @@ -138,3 +142,19 @@ def _stream( ) default_chunk_class = 
chunk.__class__ yield ChatGenerationChunk(message=chunk, generation_info=generation_info) + + def completion_with_retry( + self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any + ) -> Any: + """Use tenacity to retry the completion call.""" + retry_decorator = _create_retry_decorator(self, run_manager=run_manager) + + @retry_decorator + def _completion_with_retry(**kwargs: Any) -> Any: + return openai.ChatCompletion.create( + api_base="https://api.fireworks.ai/inference/v1", + api_key=os.environ.get("FIREWORKS_API_KEY"), + **kwargs, + ) + + return _completion_with_retry(**kwargs) diff --git a/libs/langchain/langchain/llms/fireworks.py b/libs/langchain/langchain/llms/fireworks.py index 20ab38bffcaba..84995da941af2 100644 --- a/libs/langchain/langchain/llms/fireworks.py +++ b/libs/langchain/langchain/llms/fireworks.py @@ -1,9 +1,12 @@ import os from typing import Any, Dict, Iterator, List, Optional +import backoff + from langchain.callbacks.manager import ( CallbackManagerForLLMRun, ) +from langchain.chat_models.openai import _create_retry_decorator from langchain.llms.base import LLM from langchain.schema.language_model import LanguageModelInput from langchain.schema.output import GenerationChunk @@ -29,8 +32,9 @@ class Fireworks(LLM): model = "accounts/fireworks/models/llama-v2-7b-chat" model_kwargs: Optional[dict] = {"temperature": 0.7, "max_tokens": 512, "top_p": 1} - fireworks_api_url: Optional[str] = "https://api.fireworks.ai/inference/v1" + fireworks_api_base: Optional[str] = "https://api.fireworks.ai/inference/v1" fireworks_api_key: Optional[str] = None + max_retries: int = 20 @property def _llm_type(self) -> str: @@ -44,13 +48,13 @@ def _call( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: - response = openai.Completion.create( - api_base=self.fireworks_api_url, - api_key=os.environ.get("FIREWORKS_API_KEY"), - model=self.model, - prompt=prompt, + params = { + "model": self.model, + "prompt": prompt, **self.model_kwargs, - ) + } + response = self.completion_with_retry(**params) + return response["choices"][0]["text"] def _stream( @@ -60,14 +64,13 @@ def _stream( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[GenerationChunk]: - for stream_resp in openai.Completion.create( - api_base=self.fireworks_api_url, - api_key=os.environ.get("FIREWORKS_API_KEY"), - model=self.model, - prompt=prompt, - stream=True, + params = { + "model": self.model, + "prompt": prompt, + "stream": True, **self.model_kwargs, - ): + } + for stream_resp in self.completion_with_retry(**params): chunk = _stream_response_to_generation_chunk(stream_resp) yield chunk @@ -88,3 +91,19 @@ def stream( else: generation += chunk assert generation is not None + + def completion_with_retry( + self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any + ) -> Any: + """Use tenacity to retry the completion call.""" + retry_decorator = _create_retry_decorator(self, run_manager=run_manager) + + @retry_decorator + def _completion_with_retry(**kwargs: Any) -> Any: + return openai.Completion.create( + api_base="https://api.fireworks.ai/inference/v1", + api_key=os.environ.get("FIREWORKS_API_KEY"), + **kwargs, + ) + + return _completion_with_retry(**kwargs) From 97f252a04ec978d7923f7a9fed687e003ee59e6e Mon Sep 17 00:00:00 2001 From: Cynthia Yang Date: Wed, 13 Sep 2023 05:27:15 +0000 Subject: [PATCH 07/14] Use fireworks client instead of openai --- .../langchain/chat_models/fireworks.py | 151 ++++++++++++------ 
libs/langchain/langchain/llms/fireworks.py | 81 ++++++---- 2 files changed, 155 insertions(+), 77 deletions(-) diff --git a/libs/langchain/langchain/chat_models/fireworks.py b/libs/langchain/langchain/chat_models/fireworks.py index eae8c5094c99f..e6950eaa338d0 100644 --- a/libs/langchain/langchain/chat_models/fireworks.py +++ b/libs/langchain/langchain/chat_models/fireworks.py @@ -1,16 +1,23 @@ -import os -import time -from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple - -import backoff -from langchain.adapters.openai import convert_dict_to_message, convert_message_to_dict - +import fireworks.client +from langchain.utils.env import get_from_dict_or_env +from pydantic import root_validator +from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Tuple, Union +from langchain.adapters.openai import convert_message_to_dict from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain.chat_models.base import BaseChatModel -from langchain.chat_models.openai import _create_retry_decorator +from langchain.llms.base import create_base_retry_decorator from langchain.schema.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + BaseMessageChunk, + ChatMessage, + FunctionMessage, + HumanMessage, + SystemMessage, AIMessageChunk, BaseMessage, BaseMessageChunk, @@ -20,18 +27,14 @@ SystemMessageChunk, ) from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult -import openai def _convert_delta_to_message_chunk( _dict: Mapping[str, Any], default_class: type[BaseMessageChunk] ) -> BaseMessageChunk: - role = _dict.get("role") - content = _dict.get("content") or "" - if _dict.get("function_call"): - additional_kwargs = {"function_call": dict(_dict["function_call"])} - else: - additional_kwargs = {} + role = _dict.role + content = _dict.content or "" + additional_kwargs = {} if role == "user" or default_class == HumanMessageChunk: return HumanMessageChunk(content=content) @@ -40,22 +43,46 @@ def _convert_delta_to_message_chunk( elif role == "system" or default_class == SystemMessageChunk: return SystemMessageChunk(content=content) elif role == "function" or default_class == FunctionMessageChunk: - return FunctionMessageChunk(content=content, name=_dict["name"]) + return FunctionMessageChunk(content=content, name=_dict.name) elif role or default_class == ChatMessageChunk: return ChatMessageChunk(content=content, role=role) else: return default_class(content=content) +def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: + role = _dict.role + content = _dict.content or "" + if role == "user": + return HumanMessage(content=content) + elif role == "assistant": + content = _dict.content + additional_kwargs = {} + return AIMessage(content=content, additional_kwargs=additional_kwargs) + elif role == "system": + return SystemMessage(content=content) + elif role == "function": + return FunctionMessage(content=content, name=_dict.name) + else: + return ChatMessage(content=content, role=role) + + class ChatFireworks(BaseChatModel): """Fireworks Chat models.""" model = "accounts/fireworks/models/llama-v2-7b-chat" model_kwargs: Optional[dict] = {"temperature": 0.7, "max_tokens": 512, "top_p": 1} - fireworks_api_base: Optional[str] = "https://api.fireworks.ai/inference/v1" fireworks_api_key: Optional[str] = None max_retries: int = 20 + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" 
+ values["fireworks_api_key"] = get_from_dict_or_env( + values, "fireworks_api_key", "FIREWORKS_API_KEY" + ) + return values + @property def _llm_type(self) -> str: """Return type of llm.""" @@ -75,7 +102,7 @@ def _generate( "messages": message_dicts, **self.model_kwargs, } - response = self.completion_with_retry(**params) + response = completion_with_retry(self, **params) return self._create_chat_result(response) async def _agenerate( @@ -86,13 +113,12 @@ async def _agenerate( **kwargs: Any, ) -> ChatResult: message_dicts = self._create_message_dicts(messages, stop) - response = await openai.ChatCompletion.acreate( - api_base=self.fireworks_api_base, - api_key=os.environ.get("FIREWORKS_API_KEY"), - model=self.model, - messages=message_dicts, + params = { + "model": self.model, + "messages": message_dicts, **self.model_kwargs, - ) + } + response = completion_with_retry(self, **params) return self._create_chat_result(response) def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: @@ -100,11 +126,11 @@ def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: generations = [] - for res in response["choices"]: - message = convert_dict_to_message(res["message"]) + for res in response.choices: + message = convert_dict_to_message(res.message) gen = ChatGeneration( message=message, - generation_info=dict(finish_reason=res.get("finish_reason")), + generation_info=dict(finish_reason=res.finish_reason), ) generations.append(gen) llm_output = {"model": self.model} @@ -131,30 +157,61 @@ def _stream( "stream": True, **self.model_kwargs, } - for chunk in self.completion_with_retry(**params): - choice = chunk["choices"][0] - chunk = _convert_delta_to_message_chunk( - choice["delta"], default_chunk_class - ) - finish_reason = choice.get("finish_reason") + for chunk in completion_with_retry(self, **params): + choice = chunk.choices[0] + chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class) + finish_reason = choice.finish_reason generation_info = ( dict(finish_reason=finish_reason) if finish_reason is not None else None ) default_chunk_class = chunk.__class__ yield ChatGenerationChunk(message=chunk, generation_info=generation_info) - def completion_with_retry( - self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any - ) -> Any: - """Use tenacity to retry the completion call.""" - retry_decorator = _create_retry_decorator(self, run_manager=run_manager) - - @retry_decorator - def _completion_with_retry(**kwargs: Any) -> Any: - return openai.ChatCompletion.create( - api_base="https://api.fireworks.ai/inference/v1", - api_key=os.environ.get("FIREWORKS_API_KEY"), - **kwargs, - ) - return _completion_with_retry(**kwargs) +def completion_with_retry( + llm: ChatFireworks, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, +) -> Any: + """Use tenacity to retry the completion call.""" + retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) + + @retry_decorator + def _completion_with_retry(**kwargs: Any) -> Any: + return fireworks.client.ChatCompletion.create( + **kwargs, + ) + + return _completion_with_retry(**kwargs) + + +async def acompletion_with_retry( + llm: ChatFireworks, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, +) -> Any: + """Use tenacity to retry the async completion call.""" + retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) + + @retry_decorator + async 
def _completion_with_retry(**kwargs: Any) -> Any: + return fireworks.client.ChatCompletion.acreate( + **kwargs, + ) + + return await _completion_with_retry(**kwargs) + + +def _create_retry_decorator( + llm: ChatFireworks, + run_manager: Optional[ + Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] + ] = None, +) -> Callable[[Any], Any]: + errors = [ + fireworks.client.error.RateLimitError, + fireworks.client.error.ServiceUnavailableError, + ] + return create_base_retry_decorator( + error_types=errors, max_retries=llm.max_retries, run_manager=run_manager + ) diff --git a/libs/langchain/langchain/llms/fireworks.py b/libs/langchain/langchain/llms/fireworks.py index 84995da941af2..e8f72ffa17639 100644 --- a/libs/langchain/langchain/llms/fireworks.py +++ b/libs/langchain/langchain/llms/fireworks.py @@ -1,17 +1,15 @@ -import os -from typing import Any, Dict, Iterator, List, Optional - -import backoff - +import fireworks.client +from typing import Any, Callable, Dict, Iterator, List, Optional, Union from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) -from langchain.chat_models.openai import _create_retry_decorator -from langchain.llms.base import LLM +from langchain.llms.base import LLM, create_base_retry_decorator from langchain.schema.language_model import LanguageModelInput from langchain.schema.output import GenerationChunk from langchain.schema.runnable.config import RunnableConfig -import openai +from langchain.utils.env import get_from_dict_or_env +from pydantic import root_validator def _stream_response_to_generation_chunk( @@ -19,10 +17,10 @@ def _stream_response_to_generation_chunk( ) -> GenerationChunk: """Convert a stream response to a generation chunk.""" return GenerationChunk( - text=stream_response["choices"][0]["text"], + text=stream_response.choices[0].text, generation_info=dict( - finish_reason=stream_response["choices"][0].get("finish_reason", None), - logprobs=stream_response["choices"][0].get("logprobs", None), + finish_reason=stream_response.choices[0].finish_reason, + logprobs=stream_response.choices[0].logprobs, ), ) @@ -32,10 +30,17 @@ class Fireworks(LLM): model = "accounts/fireworks/models/llama-v2-7b-chat" model_kwargs: Optional[dict] = {"temperature": 0.7, "max_tokens": 512, "top_p": 1} - fireworks_api_base: Optional[str] = "https://api.fireworks.ai/inference/v1" fireworks_api_key: Optional[str] = None max_retries: int = 20 + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + values["fireworks_api_key"] = get_from_dict_or_env( + values, "fireworks_api_key", "FIREWORKS_API_KEY" + ) + return values + @property def _llm_type(self) -> str: """Return type of llm.""" @@ -53,9 +58,9 @@ def _call( "prompt": prompt, **self.model_kwargs, } - response = self.completion_with_retry(**params) + response = completion_with_retry(self, **params) - return response["choices"][0]["text"] + return response.choices[0].text def _stream( self, @@ -70,7 +75,7 @@ def _stream( "stream": True, **self.model_kwargs, } - for stream_resp in self.completion_with_retry(**params): + for stream_resp in completion_with_retry(self, **params): chunk = _stream_response_to_generation_chunk(stream_resp) yield chunk @@ -92,18 +97,34 @@ def stream( generation += chunk assert generation is not None - def completion_with_retry( - self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any - ) -> Any: - """Use tenacity to retry the 
completion call."""
-        retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
-
-        @retry_decorator
-        def _completion_with_retry(**kwargs: Any) -> Any:
-            return openai.Completion.create(
-                api_base="https://api.fireworks.ai/inference/v1",
-                api_key=os.environ.get("FIREWORKS_API_KEY"),
-                **kwargs,
-            )
-
-        return _completion_with_retry(**kwargs)
+
+def completion_with_retry(
+    llm: Fireworks,
+    run_manager: Optional[CallbackManagerForLLMRun] = None,
+    **kwargs: Any,
+) -> Any:
+    """Use tenacity to retry the completion call."""
+    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+
+    @retry_decorator
+    def _completion_with_retry(**kwargs: Any) -> Any:
+        return fireworks.client.Completion.create(
+            **kwargs,
+        )
+
+    return _completion_with_retry(**kwargs)
+
+
+def _create_retry_decorator(
+    llm: Fireworks,
+    run_manager: Optional[
+        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+    ] = None,
+) -> Callable[[Any], Any]:
+    errors = [
+        fireworks.client.error.RateLimitError,
+        fireworks.client.error.ServiceUnavailableError,
+    ]
+    return create_base_retry_decorator(
+        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
+    )

From 9a60f4af836436c386fb61568e1a8c4f8dff3f0f Mon Sep 17 00:00:00 2001
From: Cynthia Yang
Date: Wed, 13 Sep 2023 05:33:40 +0000
Subject: [PATCH 08/14] Validate environments

---
 libs/langchain/langchain/chat_models/fireworks.py | 3 ++-
 libs/langchain/langchain/llms/fireworks.py        | 4 +++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/libs/langchain/langchain/chat_models/fireworks.py b/libs/langchain/langchain/chat_models/fireworks.py
index e6950eaa338d0..e5685466c0fa7 100644
--- a/libs/langchain/langchain/chat_models/fireworks.py
+++ b/libs/langchain/langchain/chat_models/fireworks.py
@@ -78,9 +78,10 @@ class ChatFireworks(BaseChatModel):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
-        values["fireworks_api_key"] = get_from_dict_or_env(
+        fireworks_api_key = get_from_dict_or_env(
             values, "fireworks_api_key", "FIREWORKS_API_KEY"
         )
+        fireworks.api_key = fireworks_api_key
         return values
 
     @property
diff --git a/libs/langchain/langchain/llms/fireworks.py b/libs/langchain/langchain/llms/fireworks.py
index e8f72ffa17639..96a9b4775e2b0 100644
--- a/libs/langchain/langchain/llms/fireworks.py
+++ b/libs/langchain/langchain/llms/fireworks.py
@@ -1,3 +1,4 @@
+import fireworks
 import fireworks.client
 from typing import Any, Callable, Dict, Iterator, List, Optional, Union
 from langchain.callbacks.manager import (
@@ -36,9 +37,10 @@ class Fireworks(LLM):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
-        values["fireworks_api_key"] = get_from_dict_or_env(
+        fireworks_api_key = get_from_dict_or_env(
             values, "fireworks_api_key", "FIREWORKS_API_KEY"
         )
+        fireworks.api_key = fireworks_api_key
         return values
 
     @property
From 89286da870ed0decaa8cd9101e9faf29a8046256 Mon Sep 17 00:00:00 2001
From: Cynthia Yang
Date: Wed, 13 Sep 2023 05:36:29 +0000
Subject: [PATCH 09/14] Fix api key provisioning

---
 libs/langchain/langchain/chat_models/fireworks.py | 2 +-
 libs/langchain/langchain/llms/fireworks.py        | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/libs/langchain/langchain/chat_models/fireworks.py b/libs/langchain/langchain/chat_models/fireworks.py
index e5685466c0fa7..7155d8cddf40d 100644
--- 
a/libs/langchain/langchain/chat_models/fireworks.py
+++ b/libs/langchain/langchain/chat_models/fireworks.py
@@ -81,7 +81,7 @@ def validate_environment(cls, values: Dict) -> Dict:
         fireworks_api_key = get_from_dict_or_env(
             values, "fireworks_api_key", "FIREWORKS_API_KEY"
         )
-        fireworks.api_key = fireworks_api_key
+        fireworks.client.api_key = fireworks_api_key
         return values
 
     @property
diff --git a/libs/langchain/langchain/llms/fireworks.py b/libs/langchain/langchain/llms/fireworks.py
index 96a9b4775e2b0..cd2c37cb2626d 100644
--- a/libs/langchain/langchain/llms/fireworks.py
+++ b/libs/langchain/langchain/llms/fireworks.py
@@ -40,7 +40,7 @@ def validate_environment(cls, values: Dict) -> Dict:
         fireworks_api_key = get_from_dict_or_env(
             values, "fireworks_api_key", "FIREWORKS_API_KEY"
         )
-        fireworks.api_key = fireworks_api_key
+        fireworks.client.api_key = fireworks_api_key
         return values
 
     @property
From 24c94a78ca372095c0168ebc4122500165a1d060 Mon Sep 17 00:00:00 2001
From: Cynthia Yang
Date: Wed, 13 Sep 2023 20:05:12 +0000
Subject: [PATCH 10/14] Fix async didn't wait

---
 .../langchain/chat_models/fireworks.py        | 45 +++++++++-
 libs/langchain/langchain/llms/fireworks.py    | 85 +++++++++++++++++++
 2 files changed, 129 insertions(+), 1 deletion(-)

diff --git a/libs/langchain/langchain/chat_models/fireworks.py b/libs/langchain/langchain/chat_models/fireworks.py
index 7155d8cddf40d..cafceed719c1f 100644
--- a/libs/langchain/langchain/chat_models/fireworks.py
+++ b/libs/langchain/langchain/chat_models/fireworks.py
@@ -1,3 +1,4 @@
+import fireworks
 import fireworks.client
 from langchain.utils.env import get_from_dict_or_env
 from pydantic import root_validator
@@ -119,7 +120,7 @@ async def _agenerate(
             "messages": message_dicts,
             **self.model_kwargs,
         }
-        response = completion_with_retry(self, **params)
+        response = await acompletion_with_retry(self, **params)
         return self._create_chat_result(response)
 
     def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
@@ -168,6 +169,31 @@ def _stream(
             default_chunk_class = chunk.__class__
             yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
 
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        message_dicts = self._create_message_dicts(messages, stop)
+        default_chunk_class = AIMessageChunk
+        params = {
+            "model": self.model,
+            "messages": message_dicts,
+            "stream": True,
+            **self.model_kwargs,
+        }
+        async for chunk in await acompletion_with_retry_streaming(self, **params):
+            choice = chunk.choices[0]
+            chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
+            finish_reason = choice.finish_reason
+            generation_info = (
+                dict(finish_reason=finish_reason) if finish_reason is not None else None
+            )
+            default_chunk_class = chunk.__class__
+            yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
+
 
 def completion_with_retry(
     llm: ChatFireworks,
@@ -194,6 +220,23 @@ async def acompletion_with_retry(
     """Use tenacity to retry the async completion call."""
     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 
+    @retry_decorator
+    async def _completion_with_retry(**kwargs: Any) -> Any:
+        return await fireworks.client.ChatCompletion.acreate(
+            **kwargs,
+        )
+
+    return await _completion_with_retry(**kwargs)
+
+
+async def acompletion_with_retry_streaming(
+    llm: ChatFireworks,
+    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+    
**kwargs: Any, +) -> Any: + """Use tenacity to retry the async completion call.""" + retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) + @retry_decorator async def _completion_with_retry(**kwargs: Any) -> Any: return fireworks.client.ChatCompletion.acreate( diff --git a/libs/langchain/langchain/llms/fireworks.py b/libs/langchain/langchain/llms/fireworks.py index cd2c37cb2626d..c24f7c84c20a4 100644 --- a/libs/langchain/langchain/llms/fireworks.py +++ b/libs/langchain/langchain/llms/fireworks.py @@ -64,6 +64,22 @@ def _call( return response.choices[0].text + async def _acall( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + params = { + "model": self.model, + "prompt": prompt, + **self.model_kwargs, + } + response = await acompletion_with_retry(self, **params) + + return response.choices[0].text + def _stream( self, prompt: str, @@ -81,6 +97,23 @@ def _stream( chunk = _stream_response_to_generation_chunk(stream_resp) yield chunk + async def _astream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: + params = { + "model": self.model, + "prompt": prompt, + "stream": True, + **self.model_kwargs, + } + async for stream_resp in await acompletion_with_retry_streaming(self, **params): + chunk = _stream_response_to_generation_chunk(stream_resp) + yield chunk + def stream( self, input: LanguageModelInput, @@ -99,6 +132,24 @@ def stream( generation += chunk assert generation is not None + async def astream( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> Iterator[str]: + prompt = self._convert_input(input).to_string() + generation: Optional[GenerationChunk] = None + async for chunk in self._astream(prompt): + yield chunk.text + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + def completion_with_retry( llm: Fireworks, @@ -117,6 +168,40 @@ def _completion_with_retry(**kwargs: Any) -> Any: return _completion_with_retry(**kwargs) +async def acompletion_with_retry( + llm: Fireworks, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, +) -> Any: + """Use tenacity to retry the completion call.""" + retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) + + @retry_decorator + async def _completion_with_retry(**kwargs: Any) -> Any: + return await fireworks.client.Completion.acreate( + **kwargs, + ) + + return await _completion_with_retry(**kwargs) + + +async def acompletion_with_retry_streaming( + llm: Fireworks, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, +) -> Any: + """Use tenacity to retry the completion call.""" + retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) + + @retry_decorator + async def _completion_with_retry(**kwargs: Any) -> Any: + return fireworks.client.Completion.acreate( + **kwargs, + ) + + return await _completion_with_retry(**kwargs) + + def _create_retry_decorator( llm: Fireworks, run_manager: Optional[ From 1ff7ca2f0aba235ad0a751dac25b53a15d326e0a Mon Sep 17 00:00:00 2001 From: Cynthia Yang Date: Wed, 13 Sep 2023 13:33:55 -0700 Subject: [PATCH 11/14] Update notebook --- docs/extras/integrations/chat/fireworks.ipynb | 151 +++++++++++++++- docs/extras/integrations/llms/fireworks.ipynb | 167 
+++++++++++++----- 2 files changed, 271 insertions(+), 47 deletions(-) diff --git a/docs/extras/integrations/chat/fireworks.ipynb b/docs/extras/integrations/chat/fireworks.ipynb index f0dc9b088085e..deb10148af29c 100644 --- a/docs/extras/integrations/chat/fireworks.ipynb +++ b/docs/extras/integrations/chat/fireworks.ipynb @@ -30,6 +30,17 @@ "import os" ] }, + { + "cell_type": "markdown", + "id": "f28ebf8b-f14f-46c7-9962-8b8dc42e31be", + "metadata": {}, + "source": [ + "# Setup\n", + "Contact Fireworks AI for the an API Key to access our models\n", + "\n", + "Set up your model using a model id. If the model is not set, the default model is fireworks-llama-v2-7b-chat." + ] + }, { "cell_type": "code", "execution_count": 2, @@ -42,6 +53,18 @@ "chat = ChatFireworks(model=\"accounts/fireworks/models/llama-v2-13b-chat\")" ] }, + { + "cell_type": "markdown", + "id": "d8f13144-37cf-47a5-b5a0-e3cdf76d9a72", + "metadata": {}, + "source": [ + "# Calling the Model\n", + "\n", + "You can use the LLMs to call the model for specified message(s). \n", + "\n", + "See the full, most up-to-date model list on [app.fireworks.ai](https://app.fireworks.ai)." + ] + }, { "cell_type": "code", "execution_count": 3, @@ -49,6 +72,7 @@ "metadata": {}, "outputs": [], "source": [ + "# ChatFireworks Wrapper\n", "system_message = SystemMessage(content=\"You are to chat with the user.\")\n", "human_message = HumanMessage(content=\"Who are you?\")\n", "response = chat([system_message, human_message])" @@ -63,7 +87,7 @@ { "data": { "text/plain": [ - "AIMessage(content=\"Hey there! I'm LLaMA, an AI assistant developed by Meta AI that can understand and respond to human input in a conversational manner. I'm here to help answer any questions you may have, provide information on a wide range of topics, and even assist with tasks and projects. I'm here to help make your life easier and more convenient. So, what's up? What do you want to talk about today? 😊\", additional_kwargs={}, example=False)" + "AIMessage(content=\"Hello! My name is LLaMA, I'm a large language model trained by a team of researcher at Meta AI. My primary function is to assist users with tasks and answer questions to the best of my ability. I am capable of understanding and responding to natural language input, and I am here to help you with any questions or tasks you may have. Is there anything specific you would like to know or discuss?\", additional_kwargs={}, example=False)" ] }, "execution_count": 4, @@ -77,10 +101,133 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, + "id": "68c6b1fa-2ff7-4a63-8d88-3cec302180b8", + "metadata": {}, + "outputs": [], + "source": [ + "# Setting additional parameters: temperature, max_tokens, top_p\n", + "chat = ChatFireworks(model=\"accounts/fireworks/models/llama-v2-13b-chat\", model_kwargs={\"temperature\":1, \"max_tokens\": 20, \"top_p\": 1})\n", + "system_message = SystemMessage(content=\"You are to chat with the user.\")\n", + "human_message = HumanMessage(content=\"How's the weather today?\")\n", + "response = chat([system_message, human_message])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "a09025f8-e4c3-4005-a8fc-c9c774b03a64", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content=\"Oh, you know, it's just another beautiful day in the virtual world! 
The sun\", additional_kwargs={}, example=False)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "response" + ] + }, + { + "cell_type": "markdown", + "id": "d93aa186-39cf-4e1a-aa32-01ed31d43bc8", + "metadata": {}, + "source": [ + "# ChatFireworks Wrapper with generate" + ] + }, + { + "cell_type": "code", + "execution_count": 7, "id": "cbe29efc-37c3-4c83-8b84-b8bba1a1e589", "metadata": {}, "outputs": [], + "source": [ + "chat = ChatFireworks()\n", + "message = HumanMessage(content=\"Hello\")\n", + "response = chat.generate([[message], [message]])" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "35109f36-9519-47a6-a223-25639123e836", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LLMResult(generations=[[ChatGeneration(text=\"Hello! It's nice to meet you. I'm here to help answer any questions you may have, while being respectful and safe. Please feel free to ask me anything, and I will do my best to provide helpful and positive responses. Is there something specific you would like to know or discuss?\", generation_info={'finish_reason': 'stop'}, message=AIMessage(content=\"Hello! It's nice to meet you. I'm here to help answer any questions you may have, while being respectful and safe. Please feel free to ask me anything, and I will do my best to provide helpful and positive responses. Is there something specific you would like to know or discuss?\", additional_kwargs={}, example=False))], [ChatGeneration(text=\"Hello! *smiling* I'm here to help you with any questions or concerns you may have. Please feel free to ask me anything, and I will do my best to provide helpful, respectful, and honest responses. I'm programmed to avoid any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, and to provide socially unbiased and positive responses. Is there anything specific you would like to talk about or ask?\", generation_info={'finish_reason': 'stop'}, message=AIMessage(content=\"Hello! *smiling* I'm here to help you with any questions or concerns you may have. Please feel free to ask me anything, and I will do my best to provide helpful, respectful, and honest responses. I'm programmed to avoid any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, and to provide socially unbiased and positive responses. Is there anything specific you would like to talk about or ask?\", additional_kwargs={}, example=False))]], llm_output={'model': 'accounts/fireworks/models/llama-v2-7b-chat'}, run=[RunInfo(run_id=UUID('f137463e-e1c7-454a-8b85-b999ce20e0f2')), RunInfo(run_id=UUID('f3ef1138-92de-4e01-900b-991e34a647a7'))])" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "response" + ] + }, + { + "cell_type": "markdown", + "id": "92c2cabb-9eaf-4c49-b0e5-a5de5a7d920e", + "metadata": {}, + "source": [ + "# ChatFireworks Wrapper with stream" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "12717a29-fb7d-4a4d-860b-40435452b065", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Hello! I'm just\n", + " an AI assistant,\n", + " here to help answer your\n", + " questions and provide information in\n", + " a responsible and respectful manner\n", + ". 
I'm not able\n", + " to access personal information or provide\n", + " any content that could be considered\n", + " harmful, uneth\n", + "ical, racist, sex\n", + "ist, toxic, dangerous\n", + ", or illegal. My purpose\n", + " is to assist and provide helpful\n", + " responses that are socially un\n", + "biased and positive in nature\n", + ". Is there something specific you\n", + " would like to know or discuss\n", + "?\n" + ] + } + ], + "source": [ + "llm = ChatFireworks()\n", + "\n", + "for token in llm.stream(\"Who are you\"):\n", + " print(token.content)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "02991e05-a38e-47d4-9ab3-7e630a8ead55", + "metadata": {}, + "outputs": [], "source": [] } ], diff --git a/docs/extras/integrations/llms/fireworks.ipynb b/docs/extras/integrations/llms/fireworks.ipynb index 0ea9aaf93f76e..243f2a3f5a8be 100644 --- a/docs/extras/integrations/llms/fireworks.ipynb +++ b/docs/extras/integrations/llms/fireworks.ipynb @@ -61,29 +61,7 @@ "\n", "You can use the LLMs to call the model for specified prompt(s). \n", "\n", - "Currently supported models: \n", - "\n", - "* Falcon\n", - " * `accounts/fireworks/models/falcon-7b`\n", - " * `accounts/fireworks/models/falcon-40b-w8a16`\n", - "* Llama 2\n", - " * `accounts/fireworks/models/llama-v2-7b`\n", - " * `accounts/fireworks/models/llama-v2-7b-w8a16`\n", - " * `accounts/fireworks/models/llama-v2-7b-chat`\n", - " * `accounts/fireworks/models/llama-v2-7b-chat-w8a16`\n", - " * `accounts/fireworks/models/llama-v2-13b`\n", - " * `accounts/fireworks/models/llama-v2-13b-w8a16`\n", - " * `accounts/fireworks/models/llama-v2-13b-chat`\n", - " * `accounts/fireworks/models/llama-v2-13b-chat-w8a16`\n", - " * `accounts/fireworks/models/llama-v2-70b-chat`\n", - " * `accounts/fireworks/models/llama-v2-70b-chat-w8a16`\n", - "* StarCoder\n", - " * `accounts/fireworks/models/starcoder-1b-w8a16`\n", - " * `accounts/fireworks/models/starcoder-3b-w8a16`\n", - " * `accounts/fireworks/models/starcoder-7b-w8a16`\n", - " * `accounts/fireworks/models/starcoder-16b-w8a16`\n", - "\n", - "See the full, most up-to-date list on [app.fireworks.ai](https://app.fireworks.ai)." + "See the full, most up-to-date model list on [app.fireworks.ai](https://app.fireworks.ai)." ] }, { @@ -98,21 +76,15 @@ "text": [ "\n", "\n", - "Who's the best quarterback in the NFL?\n", - "\n", - "Well, that's a tough question. There are a lot of great quarterbacks in the league right now, and it's hard to say who's the absolute best. But if I had to choose, I'd say that Tom Brady is probably the best quarterback in the NFL.\n", - "\n", - "Now, I know some people might say that Aaron Rodgers or Drew Brees or Patrick Mahomes is the best, and they have valid arguments. But here's why I think Tom Brady is the best:\n", + "It's a question that's been debated for years, and there are plenty of strong candidates. Here are some of the top quarterbacks in the league right now:\n", "\n", - "First of all, he's incredibly consistent. He's been playing at an elite level for almost two decades now, and he's never really had a bad season. He's always been able to adapt to whatever situation he's in, and he's always been able to find a way to win.\n", + "1. Tom Brady (New England Patriots): Brady is widely considered one of the greatest quarterbacks of all time, and for good reason. He's led the Patriots to six Super Bowl wins and has been named Super Bowl MVP four times. He's known for his precision passing and ability to read defenses.\n", + "2. 
Aaron Rodgers (Green Bay Packers): Rodgers is another top-tier quarterback who's known for his accuracy and ability to make plays outside of the pocket. He's led the Packers to a Super Bowl win and has been named NFL MVP twice.\n", + "3. Drew Brees (New Orleans Saints): Brees is one of the most prolific passers in NFL history, and he's shown no signs of slowing down. He's led the Saints to a Super Bowl win and has been named NFL MVP once.\n", + "4. Russell Wilson (Seattle Seahawks): Wilson is a dynamic quarterback who's known for his ability to make plays with his legs and his arm. He's led the Seahawks to a Super Bowl win and has been named NFL MVP once.\n", + "5. Patrick Mahomes (Kansas City Chiefs): Mahomes is a young quarterback who's quickly become one of the best in the league. He led the Chiefs to a Super Bowl win last season and has been named NFL MVP twice. He's known for his incredible arm talent and ability to make plays outside of the pocket.\n", "\n", - "Second, he's got an incredible work ethic. He's always been known for his dedication to the game, and he's always been willing to put in the extra work to get better. He's got a great understanding of the game, and he's always looking for ways to improve.\n", - "\n", - "Third, he's got a great supporting cast. He's played with some of the best players in the league, and he's always been able to get the most out of his teammates. He's got a great offensive line, and he's got some of the best receivers and running backs in the league.\n", - "\n", - "Now, I know some people might say that Tom Brady's success is all because of his supporting cast, and that he's not as good as some of the other quarterbacks in the league. But I think that's a bit unfair. Sure, he's had some great players around him, but he's also been the one who's led his team to six Super Bowl wins. He's the one who's made the plays when it counts, and he's the one who's always come through in the clutch.\n", - "\n", - "So, there you have it. That's why I think Tom Brady is the best quarterback in the NFL. He's consistent, he's got a great work ethic, and he's got a great supporting\n" + "Of course, there are other great quarterbacks in the league as well, such as Ben Roethlisberger, Matt Ryan, and Deshaun Watson. Ultimately, the \"best\" quarterback is a matter of personal opinion and depends on how you define \"best.\" Some people might value accuracy and precision passing, while others might prefer a quarterback who can make plays with their legs. Either way, the NFL is filled with talented quarterbacks who are making incredible plays every week.\n" ] } ], @@ -132,7 +104,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "[[Generation(text=\"\\n\\nCertainly, it is a matter of debate and personal opinion. But if we look at the stats and performances of various cricketers in 2016, here are some of the top contenders for the title of the best cricket player of the year:\\n\\n1. Virat Kohli (India): Kohli had a phenomenal year in 2016, scoring over 1,000 runs in Test cricket, including 11 centuries. He also averaged over 70 in ODIs and was instrumental in India's success in the World T20.\\n2. Steve Smith (Australia): Smith had a great year as well, scoring over 1,000 runs in Test cricket and averaging over 60. He also led Australia to a Test series victory in Sri Lanka and the West Indies.\\n3. KL Rahul (India): Rahul had a breakout year in 2016, scoring three centuries in Test cricket and averaging over 50. 
He also scored a century in the World T20 and was named the ICC Emerging Cricketer of the Year.\\n4. Joe Root (England): Root had a solid year in 2016, scoring over 1,000 runs in Test cricket and averaging over 50. He also led England to a Test series victory in Bangladesh.\\n5. David Warner (Australia): Warner had a great year in limited-overs cricket, scoring over 1,000 runs in ODIs and averaging over 60. He also scored a century in the World T20.\\n\\nOf course, there are other great cricketers who had excellent years in 2016, such as Quinton de Kock (South Africa), AB de Villiers (South Africa), and Tamim Iqbal (Bangladesh). But based on their stats and performances, these five players are among the top contenders for the title of the best cricket player of 2016.\", generation_info=None)], [Generation(text=\"\\n\\nWho's the best basketball player in the league? \", generation_info=None)]]\n" + "[[Generation(text=\"\\n\\nNote: This is a subjective question, and the answer will depend on individual opinions and perspectives.\\n\\nThere are many great cricket players, and it's difficult to identify a single best player. However, here are some of the top performers in 2016:\\n\\n1. Virat Kohli (India): Kohli had an outstanding year in all formats of the game, scoring heavily in Tests, ODIs, and T20Is. He was especially impressive in the Test series against England, where he scored four centuries and averaged over 100.\\n2. Steve Smith (Australia): Smith had a phenomenal year as well, leading Australia to a Test series win in India and averaging over 100 in the longer format. He also scored a century in the ODI series against Pakistan.\\n3. Kane Williamson (New Zealand): Williamson had a consistent year, scoring heavily in all formats and leading New Zealand to a Test series win against Australia. He also won the ICC Test Player of the Year award.\\n4. Joe Root (England): Root had a solid year, scoring three hundreds in the Test series against Pakistan and India, and averaging over 50 in Tests.\\n5. AB de Villiers (South Africa): De Villiers had a brilliant year in ODIs, scoring four hundreds and averaging over 100. He also had a good year in Tests, scoring two hundreds and averaging over 50.\\n6. Quinton de Kock (South Africa): De Kock had a great year behind the wickets, scoring heavily in all formats and averaging over 50 in Tests.\\n7. Rohit Sharma (India): Sharma had a fantastic year in ODIs, scoring four hundreds and averaging over 100. He also had a good year in Tests, scoring two hundreds and averaging over 40.\\n8. David Warner (Australia): Warner had a great year in ODIs, scoring three hundreds and averaging over 100. He also had a good year in Tests, scoring two hundreds and averaging over 40.\\n\\nThese are just a few examples of top performers in 2016, and opinions on the best player will vary depending on individual perspectives\", generation_info=None)], [Generation(text='\\n\\nThere are a lot of great players in the NBA, and opinions on who\\'s the best can vary depending on personal preferences and criteria for evaluation. However, here are some of the top candidates for the title of best basketball player in the league based on their recent performances and achievements:\\n\\n1. LeBron James: James is a four-time NBA champion and four-time MVP, and is widely regarded as one of the greatest players of all time. He has led the Los Angeles Lakers to the best record in the Western Conference this season and is averaging 25.7 points, 7.9 rebounds, and 7.4 assists per game.\\n2. 
Giannis Antetokounmpo: Antetokounmpo, known as the \"Greek Freak,\" is a dominant force in the paint and has led the Milwaukee Bucks to the best record in the Eastern Conference. He is averaging 30.5 points, 12.6 rebounds, and 5.9 assists per game, and is a strong contender for the MVP award.\\n3. Stephen Curry: Curry is a three-time NBA champion and two-time MVP, and is known for his incredible shooting ability. He has led the Golden State Warriors to the playoffs despite injuries to key players, and is averaging 23.5 points, 5.2 rebounds, and 5.2 assists per game.\\n4. Kevin Durant: Durant is a two-time NBA champion and four-time scoring champion, and is one of the most skilled scorers in the league. He has led the Brooklyn Nets to the playoffs in their first season since moving from New Jersey, and is averaging 27.2 points, 7.2 rebounds, and 6.4 assists per game.\\n5. James Harden: Harden is a three-time scoring champion and has led the Houston Rockets to the playoffs for the past eight seasons. He is averaging 35.4 points, 8.3 rebounds, and 7.5 assists per game, and is a strong contender for the MVP award.\\n\\nUltimately, determining the best basketball player in the league is subjective and depends on individual opinions and criteria. However, these five players are among', generation_info=None)]]\n" ] } ], @@ -156,7 +128,7 @@ "output_type": "stream", "text": [ "\n", - "Kansas City, Missouri, experiences a continental climate with cold winter\n" + "Kansas City's weather in December can be quite chilly,\n" ] } ], @@ -188,14 +160,15 @@ "text": [ "\n", "\n", - "Assistant: Well, there are a few options! Here are a few ideas for a company that makes football helmets:\n", + "Assistant: That's a great question! There are many factors to consider when choosing a name for a company that makes football helmets. Here are a few suggestions:\n", "\n", - "1. Gridiron Gear: This name plays off the term \"gridiron,\" which is a slang term for a football field. It also has a strong, rugged sound to it, which could appeal to customers who are looking for high-quality football helmets.\n", - "2. Helmet Haven: This name conveys a sense of safety and protection, which is important for customers who are looking for a reliable football helmet. It also has a catchy, memorable sound to it.\n", - "3. Touchdown Titan: This name has a fun, energetic sound to it, which could appeal to customers who are looking for a football helmet that's both functional and stylish. It also plays off the idea of scoring a touchdown, which is a common goal in football.\n", - "4. Football Fusion: This name combines the idea of football with the concept of fusion, which could imply a blending of different materials or technologies to create a high-quality helmet. It also has a modern, cutting-edge sound to it.\n", + "1. Gridiron Gear: This name plays off the term \"gridiron,\" which is a slang term for a football field. It also suggests that the company's products are high-quality and durable, like gear used in a gridiron game.\n", + "2. Helmet Headquarters: This name is straightforward and to the point. It clearly communicates that the company is a leading manufacturer of football helmets.\n", + "3. Tackle Tough: This name plays off the idea of tackling a tough opponent on the football field. It suggests that the company's helmets are designed to protect players from even the toughest hits.\n", + "4. Block Breakthrough: This name is a play on words that suggests the company's helmets are breaking through the competition. 
It also implies that the company is innovative and forward-thinking.\n", + "5. First Down Fashion: This name combines the idea of scoring a first down on the football field with the idea of fashionable clothing. It suggests that the company's helmets are not only functional but also stylish.\n", "\n", - "I hope these suggestions help inspire you as you come up with a name for your company!\n" + "I hope these suggestions help you come up with a great name for your company!\n" ] } ], @@ -209,11 +182,115 @@ "print(output)" ] }, + { + "cell_type": "markdown", + "id": "25812db3-23a6-41dd-8636-5a49c52bb6eb", + "metadata": {}, + "source": [ + "# Run Stream" + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "26d67ecf-9290-4ec2-8b39-ff17fc99620f", "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Tom Brady, Aaron Rod\n", + "gers, or Drew Bre\n", + "es?\n", + "Some people might\n", + " say Tom Brady, who\n", + " has won six Super Bowls\n", + " and four Super Bowl MVP\n", + " awards, is the best quarter\n", + "back in the NFL. O\n", + "thers might argue that Aaron\n", + " Rodgers, who has led\n", + " his team to a Super Bowl\n", + " victory and has been named the\n", + " NFL MVP twice, is\n", + " the best. Still, others\n", + " might say that Drew Bre\n", + "es, who holds the NFL\n", + " record for most career passing yards\n", + " and has led his team to\n", + " a Super Bowl victory, is\n", + " the best.\n", + "But what\n", + " if I told you there'\n", + "s actually a fourth quarterback\n", + " who could make a strong case\n", + " for being the best in the\n", + " NFL? Meet Russell Wilson\n", + ", the Seattle Seahaw\n", + "ks' dynamic signal-call\n", + "er who has led his team\n", + " to a Super Bowl victory and\n", + " has been named the NFL M\n", + "VP twice.\n", + "Wilson\n", + " has a unique combination of physical\n", + " and mental skills that set him\n", + " apart from other quarterbacks\n", + " in the league. He'\n", + "s incredibly athletic,\n", + " with the ability to make plays\n", + " with his feet and his arm\n", + ", and he's also\n", + " highly intelligent, with a\n", + " quick mind and the ability to\n", + " read defenses like a pro\n", + ".\n", + "But what really\n", + " sets Wilson apart is his\n", + " leadership ability. He'\n", + "s a natural-born\n", + " leader who has a way\n", + " of inspiring his team\n", + "mates and getting them\n", + " to buy into his vision\n", + " for the game. He\n", + "'s also an excellent\n", + " communicator, who can\n", + " articulate his strategy\n", + " and game plan in a\n", + " way that his teamm\n", + "ates can understand and execute\n", + ".\n", + "So, who\n", + "'s the best quarter\n", + "back in the NFL?\n", + " It's hard to\n", + " say for sure, but\n", + " if you ask me,\n", + " Russell Wilson is definitely in\n", + " the conversation. 
He'\n", + "s got the physical skills\n", + ", the mental skills,\n", + " and the leadership ability to\n", + " be the best of the\n", + " best.\n" + ] + } + ], + "source": [ + "llm = Fireworks()\n", + "generator = llm.stream(\"Who's the best quarterback in the NFL?\")\n", + "\n", + "for token in generator:\n", + " print(token)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e3a35e0b-c875-493a-8143-d802d273247c", + "metadata": {}, "outputs": [], "source": [] } From 93a875970681574e03a89399d1495ed1b1d4a8a0 Mon Sep 17 00:00:00 2001 From: Cynthia Yang Date: Thu, 14 Sep 2023 18:45:35 +0000 Subject: [PATCH 12/14] Add comments --- libs/langchain/langchain/chat_models/fireworks.py | 7 +++++-- libs/langchain/langchain/llms/fireworks.py | 7 +++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/libs/langchain/langchain/chat_models/fireworks.py b/libs/langchain/langchain/chat_models/fireworks.py index cafceed719c1f..9153b79d3f4c3 100644 --- a/libs/langchain/langchain/chat_models/fireworks.py +++ b/libs/langchain/langchain/chat_models/fireworks.py @@ -33,6 +33,7 @@ def _convert_delta_to_message_chunk( _dict: Mapping[str, Any], default_class: type[BaseMessageChunk] ) -> BaseMessageChunk: + """Convert a delta response to a message chunk.""" role = _dict.role content = _dict.content or "" additional_kwargs = {} @@ -52,6 +53,7 @@ def _convert_delta_to_message_chunk( def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: + """Convert a dict response to a message.""" role = _dict.role content = _dict.content or "" if role == "user": @@ -78,7 +80,7 @@ class ChatFireworks(BaseChatModel): @root_validator() def validate_environment(cls, values: Dict) -> Dict: - """Validate that api key and python package exists in environment.""" + """Validate that api key in environment.""" fireworks_api_key = get_from_dict_or_env( values, "fireworks_api_key", "FIREWORKS_API_KEY" ) @@ -234,7 +236,7 @@ async def acompletion_with_retry_streaming( run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Any: - """Use tenacity to retry the async completion call.""" + """Use tenacity to retry the completion call for streaming.""" retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) @retry_decorator @@ -252,6 +254,7 @@ def _create_retry_decorator( Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] ] = None, ) -> Callable[[Any], Any]: + """Define retry mechanism.""" errors = [ fireworks.client.error.RateLimitError, fireworks.client.error.ServiceUnavailableError, diff --git a/libs/langchain/langchain/llms/fireworks.py b/libs/langchain/langchain/llms/fireworks.py index c24f7c84c20a4..86eaf8bf1e934 100644 --- a/libs/langchain/langchain/llms/fireworks.py +++ b/libs/langchain/langchain/llms/fireworks.py @@ -36,7 +36,7 @@ class Fireworks(LLM): @root_validator() def validate_environment(cls, values: Dict) -> Dict: - """Validate that api key and python package exists in environment.""" + """Validate that api key in environment.""" fireworks_api_key = get_from_dict_or_env( values, "fireworks_api_key", "FIREWORKS_API_KEY" ) @@ -55,6 +55,7 @@ def _call( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: + """Run the LLM on the given prompt and input.""" params = { "model": self.model, "prompt": prompt, @@ -71,6 +72,7 @@ async def _acall( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: + """Run the LLM on the given prompt and input.""" params = { "model": self.model, 
"prompt": prompt, @@ -190,7 +192,7 @@ async def acompletion_with_retry_streaming( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Any: - """Use tenacity to retry the completion call.""" + """Use tenacity to retry the completion call for streaming.""" retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) @retry_decorator @@ -208,6 +210,7 @@ def _create_retry_decorator( Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] ] = None, ) -> Callable[[Any], Any]: + """Define retry mechanism.""" errors = [ fireworks.client.error.RateLimitError, fireworks.client.error.ServiceUnavailableError, From d820b63441656f79ecaf435a0d2f08c7db24378f Mon Sep 17 00:00:00 2001 From: Cynthia Yang Date: Thu, 14 Sep 2023 18:48:46 +0000 Subject: [PATCH 13/14] Refactor --- libs/langchain/langchain/chat_models/fireworks.py | 2 +- libs/langchain/langchain/llms/fireworks.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/chat_models/fireworks.py b/libs/langchain/langchain/chat_models/fireworks.py index 9153b79d3f4c3..0922e3b338820 100644 --- a/libs/langchain/langchain/chat_models/fireworks.py +++ b/libs/langchain/langchain/chat_models/fireworks.py @@ -73,7 +73,7 @@ def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: class ChatFireworks(BaseChatModel): """Fireworks Chat models.""" - model = "accounts/fireworks/models/llama-v2-7b-chat" + model: str = "accounts/fireworks/models/llama-v2-7b-chat" model_kwargs: Optional[dict] = {"temperature": 0.7, "max_tokens": 512, "top_p": 1} fireworks_api_key: Optional[str] = None max_retries: int = 20 diff --git a/libs/langchain/langchain/llms/fireworks.py b/libs/langchain/langchain/llms/fireworks.py index 86eaf8bf1e934..fdea4019dbb49 100644 --- a/libs/langchain/langchain/llms/fireworks.py +++ b/libs/langchain/langchain/llms/fireworks.py @@ -29,7 +29,7 @@ def _stream_response_to_generation_chunk( class Fireworks(LLM): """Fireworks models.""" - model = "accounts/fireworks/models/llama-v2-7b-chat" + model: str = "accounts/fireworks/models/llama-v2-7b-chat" model_kwargs: Optional[dict] = {"temperature": 0.7, "max_tokens": 512, "top_p": 1} fireworks_api_key: Optional[str] = None max_retries: int = 20 From 648aebb403ada4de330b77a5ccad0483d6647b32 Mon Sep 17 00:00:00 2001 From: Cynthia Yang Date: Thu, 14 Sep 2023 20:06:10 +0000 Subject: [PATCH 14/14] Add dependency for fireworks-ai --- libs/langchain/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 3d0661332397c..8a715f393161e 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -129,6 +129,7 @@ markdownify = {version = "^0.11.6", optional = true} assemblyai = {version = "^0.17.0", optional = true} dashvector = {version = "^1.0.1", optional = true} sqlite-vss = {version = "^0.1.2", optional = true} +fireworks-ai = {version = "^0.4.1", optional = true} [tool.poetry.group.test.dependencies]