Skip to content

Commit

Permalink
Merge pull request #2 from VowpalWabbit/fixes
Browse files Browse the repository at this point in the history
Dependency and import fixes
  • Loading branch information
olgavrou authored Aug 22, 2023
2 parents e942330 + c9e9c0e commit 571ee71
Show file tree
Hide file tree
Showing 11 changed files with 295 additions and 618 deletions.
26 changes: 19 additions & 7 deletions libs/langchain/langchain/chains/rl_chain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from langchain.chains.rl_chain.pick_best_chain import PickBest
import logging

from langchain.chains.rl_chain.base import (
Embed,
BasedOn,
ToSelectFrom,
SelectionScorer,
AutoSelectionScorer,
BasedOn,
Embed,
Embedder,
Policy,
SelectionScorer,
ToSelectFrom,
VwPolicy,
)

import logging
from langchain.chains.rl_chain.pick_best_chain import PickBest


def configure_logger():
Expand All @@ -26,3 +26,15 @@ def configure_logger():


configure_logger()

# Public names of the rl_chain package; this list is what
# `from langchain.chains.rl_chain import *` exposes.
__all__ = [
"PickBest",
"Embed",
"BasedOn",
"ToSelectFrom",
"SelectionScorer",
"AutoSelectionScorer",
"Embedder",
"Policy",
"VwPolicy",
]
121 changes: 69 additions & 52 deletions libs/langchain/langchain/chains/rl_chain/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,25 @@

import logging
import os
from typing import Any, Dict, List, Optional, Tuple, Union, Sequence
from abc import ABC, abstractmethod

import vowpal_wabbit_next as vw
from langchain.chains.rl_chain.vw_logger import VwLogger
from langchain.chains.rl_chain.model_repository import ModelRepository
from langchain.chains.rl_chain.metrics import MetricsTracker
from langchain.prompts import BasePromptTemplate

from langchain.pydantic_v1 import Extra, BaseModel, root_validator
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.rl_chain.metrics import MetricsTracker
from langchain.chains.rl_chain.model_repository import ModelRepository
from langchain.chains.rl_chain.vw_logger import VwLogger
from langchain.prompts import (
BasePromptTemplate,
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.pydantic_v1 import BaseModel, Extra, root_validator

if TYPE_CHECKING:
import vowpal_wabbit_next as vw

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -87,7 +87,7 @@ def EmbedAndKeep(anything):
# helper functions


def parse_lines(parser: vw.TextFormatParser, input_str: str) -> List[vw.Example]:
def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Example"]:
    """Parse every newline-separated line of ``input_str`` with ``parser``.

    The input is split on ``"\\n"`` only, so an empty string or a trailing
    newline still yields an (empty) line that is handed to the parser.

    Args:
        parser: Object exposing ``parse_line(line)``.
        input_str: Text with one example per line.

    Returns:
        The results of ``parser.parse_line`` in input order.
    """
    examples = []
    for raw_line in input_str.split("\n"):
        examples.append(parser.parse_line(raw_line))
    return examples


Expand All @@ -100,7 +100,8 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]):

if not to_select_from:
raise ValueError(
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
"No variables using 'ToSelectFrom' found in the inputs. \
Please include at least one variable containing a list to select from."
)

based_on = {
Expand All @@ -113,8 +114,11 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]):


def prepare_inputs_for_autoembed(inputs: Dict[str, Any]):
# go over all the inputs and if something is either wrapped in _ToSelectFrom or _BasedOn, and if
# their inner values are not already _Embed, then wrap them in EmbedAndKeep while retaining their _ToSelectFrom or _BasedOn status
"""
Go over all the inputs: if a value is wrapped in _ToSelectFrom or _BasedOn and its inner value is not already an _Embed,
wrap that inner value in EmbedAndKeep while retaining its _ToSelectFrom or _BasedOn status.
""" # noqa: E501

next_inputs = inputs.copy()
for k, v in next_inputs.items():
if isinstance(v, _ToSelectFrom) or isinstance(v, _BasedOn):
Expand Down Expand Up @@ -173,14 +177,17 @@ def __init__(
self.vw_logger = vw_logger

def predict(self, event: Event) -> Any:
    """Return the workspace's prediction for ``event``.

    The event is rendered to text by ``self.feature_embedder.format``,
    parsed line by line into examples, and passed to
    ``self.workspace.predict_one``.
    """
    # Imported here rather than at module level, so the package is only
    # needed once a prediction is actually requested.
    import vowpal_wabbit_next as vw

    text_parser = vw.TextFormatParser(self.workspace)
    formatted_event = self.feature_embedder.format(event)
    examples = parse_lines(text_parser, formatted_event)
    return self.workspace.predict_one(examples)

def learn(self, event: Event):
vw_ex = self.feature_embedder.format(event)
import vowpal_wabbit_next as vw

vw_ex = self.feature_embedder.format(event)
text_parser = vw.TextFormatParser(self.workspace)
multi_ex = parse_lines(text_parser, vw_ex)
self.workspace.learn_one(multi_ex)
Expand Down Expand Up @@ -216,13 +223,18 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
@staticmethod
def get_default_system_prompt() -> SystemMessagePromptTemplate:
return SystemMessagePromptTemplate.from_template(
"PLEASE RESPOND ONLY WITH A SIGNLE FLOAT AND NO OTHER TEXT EXPLANATION\n You are a strict judge that is called on to rank a response based on given criteria.\
You must respond with your ranking by providing a single float within the range [0, 1], 0 being very bad response and 1 being very good response."
"PLEASE RESPOND ONLY WITH A SINGLE FLOAT AND NO OTHER TEXT EXPLANATION\n \
You are a strict judge that is called on to rank a response based on \
given criteria. You must respond with your ranking by providing a \
single float within the range [0, 1], 0 being very bad \
response and 1 being very good response."
)

@staticmethod
def get_default_prompt() -> ChatPromptTemplate:
human_template = 'Given this based_on "{rl_chain_selected_based_on}" as the most important attribute, rank how good or bad this text is: "{llm_response}".'
human_template = 'Given this based_on "{rl_chain_selected_based_on}" \
as the most important attribute, rank how good or bad this text is: \
"{llm_response}".'
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
default_system_prompt = AutoSelectionScorer.get_default_system_prompt()
chat_prompt = ChatPromptTemplate.from_messages(
Expand Down Expand Up @@ -257,25 +269,36 @@ def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
return resp
except Exception as e:
raise RuntimeError(
f"The llm did not manage to rank the response as expected, there is always the option to try again or tweak the reward prompt. Error: {e}"
f"The auto selection scorer did not manage to score the response, \
there is always the option to try again or tweak the reward prompt.\
Error: {e}"
)


class RLChain(Chain):
"""
RLChain class that utilizes the Vowpal Wabbit (VW) model for personalization.
The `RLChain` class leverages the Vowpal Wabbit (VW) model as a learned policy for reinforcement learning.
Attributes:
model_loading (bool, optional): If set to True, the chain will attempt to load an existing VW model from the latest checkpoint file in the {model_save_dir} directory (current directory if none specified). If set to False, it will start training from scratch, potentially overwriting existing files. Defaults to True.
large_action_spaces (bool, optional): If set to True and vw_cmd has not been specified in the constructor, it will enable large action spaces
vw_cmd (List[str], optional): Advanced users can set the VW command line to whatever they want, as long as it is compatible with the Type that is specified (Type Enum)
model_save_dir (str, optional): The directory to save the VW model to. Defaults to the current directory.
selection_scorer (SelectionScorer): If set, the chain will check the response using the provided selection_scorer and the VW model will be updated with the result. Defaults to None.
- llm_chain (Chain): Represents the underlying Language Model chain.
- prompt (BasePromptTemplate): The template for the base prompt.
- selection_scorer (Union[SelectionScorer, None]): Scorer for the selection. Can be set to None.
- policy (Optional[Policy]): The policy used by the chain to learn to populate a dynamic prompt.
- auto_embed (bool): Determines if embedding should be automatic. Default is True.
- metrics (Optional[MetricsTracker]): Tracker for metrics, can be set to None.
Initialization Attributes:
- feature_embedder (Embedder): Embedder used for the `BasedOn` and `ToSelectFrom` inputs.
- model_save_dir (str, optional): Directory for saving the VW model. Default is the current directory.
- reset_model (bool): If set to True, the model starts training from scratch. Default is False.
- vw_cmd (List[str], optional): Command line arguments for the VW model.
- policy (VwPolicy): Policy used by the chain.
- vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs.
- metrics_step (int): Step for the metrics tracker. Default is -1.
Notes:
The class creates a VW model instance using the provided arguments. Before the chain object is destroyed the save_progress() function can be called. If it is called, the learned VW model is saved to a file in the current directory named `model-<checkpoint>.vw`. Checkpoints start at 1 and increment monotonically.
When making predictions, VW is first called to choose action(s) which are then passed into the prompt with the key `{actions}`. After action selection, the LLM (Language Model) is called with the prompt populated by the chosen action(s), and the response is returned.
"""
The class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called.
""" # noqa: E501

llm_chain: Chain

Expand Down Expand Up @@ -303,7 +326,9 @@ def __init__(
super().__init__(*args, **kwargs)
if self.selection_scorer is None:
logger.warning(
"No response validator provided, which means that no reinforcement learning will be done in the RL chain unless update_with_delayed_score is called."
"No selection scorer provided, which means that no \
reinforcement learning will be done in the RL chain \
unless update_with_delayed_score is called."
)
self.policy = policy(
model_repo=ModelRepository(
Expand Down Expand Up @@ -343,7 +368,9 @@ def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
or self.selected_based_on_input_key in inputs.keys()
):
raise ValueError(
f"The rl chain does not accept '{self.selected_input_key}' or '{self.selected_based_on_input_key}' as input keys, they are reserved for internal use during auto reward."
f"The rl chain does not accept '{self.selected_input_key}' \
or '{self.selected_based_on_input_key}' as input keys, \
they are reserved for internal use during auto reward."
)

@abstractmethod
Expand Down Expand Up @@ -372,13 +399,13 @@ def update_with_delayed_score(
self, score: float, event: Event, force_score=False
) -> None:
"""
Learn will be called with the score specified and the actions/embeddings/etc stored in event
Updates the learned policy with the score provided.
Will raise an error if selection_scorer is set, and force_score=True was not provided during the method call
"""
""" # noqa: E501
if self.selection_scorer and not force_score:
raise RuntimeError(
"The selection scorer is set, and force_score was not set to True. Please set force_score=True to use this function."
"The selection scorer is set, and force_score was not set to True. \
Please set force_score=True to use this function."
)
self.metrics.on_feedback(score)
self._call_after_scoring_before_learning(event=event, score=score)
Expand All @@ -387,10 +414,7 @@ def update_with_delayed_score(

def set_auto_embed(self, auto_embed: bool) -> None:
"""
Set whether the chain should auto embed the inputs or not. If set to False, the inputs will not be embedded and the user will need to embed the inputs themselves before calling run.
Args:
auto_embed (bool): Whether the chain should auto embed the inputs or not.
Sets whether the chain should auto embed the inputs or not.
"""
self.auto_embed = auto_embed

Expand Down Expand Up @@ -435,7 +459,8 @@ def _call(
)
except Exception as e:
logger.info(
f"The LLM was not able to rank and the chain was not able to adjust to this response, error: {e}"
f"The selection scorer was not able to score, \
and the chain was not able to adjust to this response, error: {e}"
)
self.metrics.on_feedback(score)
event = self._call_after_scoring_before_learning(score=score, event=event)
Expand All @@ -446,16 +471,7 @@ def _call(

def save_progress(self) -> None:
"""
This function should be called whenever there is a need to save the progress of the VW (Vowpal Wabbit) model within the chain. It saves the current state of the VW model to a file.
File Naming Convention:
The file will be named using the pattern `model-<checkpoint>.vw`, where `<checkpoint>` is a monotonically increasing number. The numbering starts from 1, and increments by 1 for each subsequent save. If there are already saved checkpoints, the number used for `<checkpoint>` will be the next in the sequence.
Example:
If there are already two saved checkpoints, `model-1.vw` and `model-2.vw`, the next time this function is called, it will save the model as `model-3.vw`.
Note:
Be cautious when deleting or renaming checkpoint files manually, as this could cause the function to reuse checkpoint numbers.
This function should be called to save the state of the learned policy model.
"""
self.policy.save()

Expand Down Expand Up @@ -490,7 +506,8 @@ def embed_string_type(

if namespace is None:
raise ValueError(
"The default namespace must be provided when embedding a string or _Embed object."
"The default namespace must be \
provided when embedding a string or _Embed object."
)

return {namespace: keep_str + join_char.join(map(str, encoded))}
Expand Down Expand Up @@ -530,15 +547,15 @@ def embed(
namespace: Optional[str] = None,
) -> List[Dict[str, Union[str, List[str]]]]:
"""
Embeds the actions or context using the SentenceTransformer model
Embeds the actions or context using the SentenceTransformer model (or a model that has an `encode` function)
Attributes:
to_embed: (Union[Union(str, _Embed(str)), Dict, List[Union(str, _Embed(str))], List[Dict]], required) The text to be embedded, either a string, a list of strings or a dictionary or a list of dictionaries.
namespace: (str, optional) The default namespace to use when dictionary or list of dictionaries not provided.
model: (Any, required) The model to use for embedding
Returns:
List[Dict[str, str]]: A list of dictionaries where each dictionary has the namespace as the key and the embedded string as the value
"""
""" # noqa: E501
if (isinstance(to_embed, _Embed) and isinstance(to_embed.value, str)) or isinstance(
to_embed, str
):
Expand Down
10 changes: 7 additions & 3 deletions libs/langchain/langchain/chains/rl_chain/metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pandas as pd
from typing import Optional
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
import pandas as pd


class MetricsTracker:
Expand All @@ -23,5 +25,7 @@ def on_feedback(self, score: Optional[float]) -> None:
if self._step > 0 and self._i % self._step == 0:
self._history.append({"step": self._i, "score": self.score})

def to_pandas(self) -> pd.DataFrame:
def to_pandas(self) -> "pd.DataFrame":
    """Return the recorded history as a pandas ``DataFrame``.

    pandas is imported inside the method, so it is only required when
    this conversion is actually called.
    """
    import pandas as pd

    history_frame = pd.DataFrame(self._history)
    return history_frame
18 changes: 11 additions & 7 deletions libs/langchain/langchain/chains/rl_chain/model_repository.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from pathlib import Path
import shutil
import datetime
import vowpal_wabbit_next as vw
from typing import Union, Sequence
import os
import glob
import logging
import os
import shutil
from pathlib import Path
from typing import TYPE_CHECKING, Sequence, Union

if TYPE_CHECKING:
import vowpal_wabbit_next as vw

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -35,14 +37,16 @@ def get_tag(self) -> str:
def has_history(self) -> bool:
return len(glob.glob(str(self.folder / "model-????????-??????.vw"))) > 0

def save(self, workspace: vw.Workspace) -> None:
# Write the serialized workspace to self.model_path and, when
# self.with_history is set, also copy it to a tagged file in self.folder.
# NOTE(review): indentation is lost in this view; presumably the history
# copy runs after the `with` block has closed the file, so the copied
# file is complete - confirm against the real source.
def save(self, workspace: "vw.Workspace") -> None:
# Binary mode: the output of workspace.serialize() is written as-is.
with open(self.model_path, "wb") as f:
logger.info(f"storing rl_chain model in: {self.model_path}")
f.write(workspace.serialize())
if self.with_history: # write history
# The history copy is named model-<tag>.vw, with the tag from get_tag().
shutil.copyfile(self.model_path, self.folder / f"model-{self.get_tag()}.vw")

def load(self, commandline: Sequence[str]) -> vw.Workspace:
def load(self, commandline: Sequence[str]) -> "vw.Workspace":
import vowpal_wabbit_next as vw

model_data = None
if self.model_path.exists():
with open(self.model_path, "rb") as f:
Expand Down
Loading

0 comments on commit 571ee71

Please sign in to comment.