Refuel LLM integration with autolabel #595

Merged 4 commits on Oct 13, 2023
Changes from all commits
4 changes: 2 additions & 2 deletions src/autolabel/confidence.py
@@ -27,7 +27,7 @@ def __init__(
) -> None:
self.score_type = score_type
self.llm = llm
self.tokens_to_ignore = {"<unk>"}
self.tokens_to_ignore = {"<unk>", "", "\\n"}
self.SUPPORTED_CALCULATORS = {
"logprob_average": self.logprob_average,
"p_true": self.p_true,
@@ -54,7 +54,7 @@ def logprob_average(
logprob_cumulative, count = 0, 0
for token in logprobs:
token_str = list(token.keys())[0]
if token_str not in self.tokens_to_ignore:
if token_str.strip() not in self.tokens_to_ignore:
logprob_cumulative += (
token[token_str]
if token[token_str] >= 0
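A minimal sketch of the effect of this confidence.py change, assuming logprobs arrive as a list of single-entry {token: logprob} dicts (as the loop above implies); the widened ignore set plus .strip() keeps whitespace-only tokens out of the average:

```python
from typing import Dict, List

TOKENS_TO_IGNORE = {"<unk>", "", "\\n"}  # mirrors the updated tokens_to_ignore set


def logprob_average(logprobs: List[Dict[str, float]]) -> float:
    """Average token log-probabilities, skipping unknown and whitespace-only tokens."""
    total, count = 0.0, 0
    for token in logprobs:
        token_str = list(token.keys())[0]
        # strip() collapses whitespace-only tokens to "", which is now ignored
        if token_str.strip() not in TOKENS_TO_IGNORE:
            total += token[token_str]
            count += 1
    return total / count if count else 0.0


# The literal "\\n" token and the bare space are skipped; only "Yes" counts
print(logprob_average([{"Yes": -0.1}, {"\\n": -2.3}, {" ": -1.5}]))  # -0.1
```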
6 changes: 4 additions & 2 deletions src/autolabel/few_shot/label_diversity_example_selector.py
@@ -106,7 +106,9 @@ def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
"""Select which examples to use based on label diversity and semantic similarity."""
# Get the docs with the highest similarity for each label.
if self.input_keys:
input_variables = {key: input_variables[key] for key in self.input_keys}
input_variables = {
str(key): str(input_variables[key]) for key in self.input_keys
}
query = " ".join(sorted_values(input_variables))
num_examples_per_label = math.ceil(self.k / self.num_labels)
example_docs = self.vectorstore.label_diversity_similarity_search(
@@ -146,7 +148,7 @@ def from_examples(
"""
if input_keys:
string_examples = [
" ".join(sorted_values({k: eg[k] for k in input_keys}))
" ".join(sorted_values({str(k): str(eg[k]) for k in input_keys}))
for eg in examples
]
else:
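The str() coercion above guards against non-string keys or values (for example numeric columns) breaking the " ".join call. A small illustration, using a stand-in for Autolabel's sorted_values helper (assumed here to return values ordered by key):

```python
def sorted_values(d: dict) -> list:
    # stand-in for autolabel.utils.sorted_values; assumed to return values ordered by key
    return [d[k] for k in sorted(d.keys())]


example = {"question": "What is 2 + 2?", "answer": 4}  # note the non-string value
input_keys = ["question", "answer"]

# Without the str() casts, " ".join(...) raises TypeError on the integer 4
query = " ".join(sorted_values({str(k): str(example[k]) for k in input_keys}))
print(query)  # "4 What is 2 + 2?"
```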
2 changes: 1 addition & 1 deletion src/autolabel/labeler.py
@@ -443,7 +443,7 @@ def plan(
table, show_header=False, console=self.console, styles=COST_TABLE_STYLES
)
self.console.rule("Prompt Example")
self.console.print(f"{prompt_list[0]}")
self.console.print(f"{prompt_list[0]}", markup=False)
self.console.rule()

async def async_run_transform(
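The markup=False flag matters because Refuel LLM prompts now contain bracketed tokens such as [INST]; a sketch of the failure mode with rich (the Console here stands in for the labeler's existing console):

```python
from rich.console import Console

console = Console()
prompt = "<s>[INST] <<SYS>>\nClassify the ticket.\n<<SYS>>\nTicket: ...[/INST]"

# With markup enabled (the default), rich tries to parse "[INST]" / "[/INST]" as style
# tags, which can raise a markup error or silently drop them from the printed prompt.
console.print(f"{prompt}", markup=False)  # prints the prompt verbatim
```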
41 changes: 21 additions & 20 deletions src/autolabel/models/refuel.py
@@ -23,7 +23,6 @@
class RefuelLLM(BaseModel):
DEFAULT_PARAMS = {
"max_new_tokens": 128,
"temperature": 0.0,
}

def __init__(
@@ -41,8 +40,7 @@ def __init__(
self.model_params = {**self.DEFAULT_PARAMS, **model_params}

# initialize runtime
self.BASE_API = "https://refuel-llm.refuel.ai/"
self.SEP_REPLACEMENT_TOKEN = "@@"
self.BASE_API = f"https://llm.refuel.ai/models/{self.model_name}/generate"
self.REFUEL_API_ENV = "REFUEL_API_KEY"
if self.REFUEL_API_ENV in os.environ and os.environ[self.REFUEL_API_ENV]:
self.REFUEL_API_KEY = os.environ[self.REFUEL_API_ENV]
@@ -60,8 +58,9 @@ def __init__(
)
def _label_with_retry(self, prompt: str) -> requests.Response:
payload = {
"data": {"model_input": prompt, "model_params": {**self.model_params}},
"task": "generate",
"input": prompt,
"params": {**self.model_params},
"confidence": self.config.confidence(),
}
headers = {"refuel_api_key": self.REFUEL_API_KEY}
response = requests.post(self.BASE_API, json=payload, headers=headers)
@@ -74,20 +73,20 @@ def _label(self, prompts: List[str]) -> RefuelLLMResult:
errors = []
for prompt in prompts:
try:
if self.SEP_REPLACEMENT_TOKEN in prompt:
logger.warning(
f"""Current prompt contains {self.SEP_REPLACEMENT_TOKEN}
which is currently used as a separator token by refuel
llm. It is highly recommended to avoid having any
occurences of this substring in the prompt.
"""
)
separated_prompt = prompt.replace("\n", self.SEP_REPLACEMENT_TOKEN)
response = self._label_with_retry(separated_prompt)
response = json.loads(response.json()["body"]).replace(
self.SEP_REPLACEMENT_TOKEN, "\n"
response = self._label_with_retry(prompt)
response = json.loads(response.json())
generations.append(
[
Generation(
text=response["generated_text"],
generation_info={
"logprobs": {"top_logprobs": response["logprobs"]}
}
if self.config.confidence()
else None,
)
]
)
generations.append([Generation(text=response)])
errors.append(None)
except Exception as e:
# This signifies an error in generating the response using RefuelLLm
@@ -96,12 +95,14 @@ def _label(self, prompts: List[str]) -> RefuelLLMResult:
)
generations.append([Generation(text="")])
errors.append(
LabelingError(error_type=ErrorType.LLM_PROVIDER_ERROR, error=e)
LabelingError(
error_type=ErrorType.LLM_PROVIDER_ERROR, error_message=str(e)
)
)
return RefuelLLMResult(generations=generations, errors=errors)

def get_cost(self, prompt: str, label: Optional[str] = "") -> float:
return 0

def returns_token_probs(self) -> bool:
return False
return True
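Roughly, the updated RefuelLLM client now issues requests like the sketch below. The model name and prompt are hypothetical; the endpoint, payload shape, and response fields follow the diff above and the updated tests:

```python
import json
import os

import requests

model_name = "refuel-llm"  # hypothetical; taken from the Autolabel config in practice
model_params = {"max_new_tokens": 128}
return_confidence = True

payload = {
    "input": "Label the sentiment of: 'great product!'",
    "params": {**model_params},
    "confidence": return_confidence,
}
headers = {"refuel_api_key": os.environ["REFUEL_API_KEY"]}

# The endpoint is now parameterized by model name instead of a single fixed URL
response = requests.post(
    f"https://llm.refuel.ai/models/{model_name}/generate",
    json=payload,
    headers=headers,
)

result = json.loads(response.json())
print(result["generated_text"])  # model output text
if return_confidence:
    print(result["logprobs"])  # per-token logprobs fed to the confidence calculator
```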
32 changes: 28 additions & 4 deletions src/autolabel/tasks/base.py
@@ -16,6 +16,7 @@
TaskType,
LabelingError,
ErrorType,
ModelProvider,
)
from autolabel.utils import (
get_format_variables,
@@ -30,6 +31,17 @@ class BaseTask(ABC):
ZERO_SHOT_TEMPLATE = "{task_guidelines}\n\n{output_guidelines}\n\nNow I want you to label the following example:\n{current_example}"
FEW_SHOT_TEMPLATE = "{task_guidelines}\n\n{output_guidelines}\n\nSome examples with their output answers are provided below:\n\n{seed_examples}\n\nNow I want you to label the following example:\n{current_example}"

ZERO_SHOT_TEMPLATE_REFUEL_LLM = """
<s>[INST] <<SYS>>
{task_guidelines}{output_guidelines}
<<SYS>>
{current_example}[/INST]\n"""
FEW_SHOT_TEMPLATE_REFUEL_LLM = """
<s>[INST] <<SYS>>
{task_guidelines}{output_guidelines}\n{seed_examples}
<<SYS>>
{current_example}[/INST]\n"""

# Downstream classes should override these
NULL_LABEL_TOKEN = "NO_LABEL"
DEFAULT_TASK_GUIDELINES = ""
@@ -39,6 +51,8 @@ class BaseTask(ABC):
def __init__(self, config: AutolabelConfig) -> None:
self.config = config

is_refuel_llm = self.config.provider() == ModelProvider.REFUEL

# Update the default prompt template with the prompt template from the config
self.task_guidelines = (
self.config.task_guidelines() or self.DEFAULT_TASK_GUIDELINES
@@ -48,14 +62,24 @@ def __init__(self, config: AutolabelConfig) -> None:
)

if self._is_few_shot_mode():
few_shot_template = (
self.FEW_SHOT_TEMPLATE_REFUEL_LLM
if is_refuel_llm
else self.FEW_SHOT_TEMPLATE
)
self.prompt_template = PromptTemplate(
input_variables=get_format_variables(self.FEW_SHOT_TEMPLATE),
template=self.FEW_SHOT_TEMPLATE,
input_variables=get_format_variables(few_shot_template),
template=few_shot_template,
)
else:
zero_shot_template = (
self.ZERO_SHOT_TEMPLATE_REFUEL_LLM
if is_refuel_llm
else self.ZERO_SHOT_TEMPLATE
)
self.prompt_template = PromptTemplate(
input_variables=get_format_variables(self.ZERO_SHOT_TEMPLATE),
template=self.ZERO_SHOT_TEMPLATE,
input_variables=get_format_variables(zero_shot_template),
template=zero_shot_template,
)

self.dataset_generation_guidelines = (
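For reference, a quick sketch of what the new Refuel-LLM zero-shot template renders to. The guideline and example strings are hypothetical; the template itself is copied from the diff above:

```python
ZERO_SHOT_TEMPLATE_REFUEL_LLM = """
<s>[INST] <<SYS>>
{task_guidelines}{output_guidelines}
<<SYS>>
{current_example}[/INST]\n"""

prompt = ZERO_SHOT_TEMPLATE_REFUEL_LLM.format(
    task_guidelines="You are an expert at classifying banking support tickets.\n",
    output_guidelines="Answer with exactly one of the provided labels.\n",
    current_example="Ticket: My card was declined.\nLabel:",
)
# Renders a Llama-2-chat-style prompt: <s>[INST] <<SYS>> ... <<SYS>> ... [/INST]
print(prompt)
```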
17 changes: 14 additions & 3 deletions src/autolabel/tasks/classification.py
@@ -6,7 +6,7 @@

from autolabel.confidence import ConfidenceCalculator
from autolabel.configs import AutolabelConfig
from autolabel.schema import LLMAnnotation, MetricType, MetricResult
from autolabel.schema import LLMAnnotation, MetricType, MetricResult, ModelProvider
from autolabel.tasks import BaseTask
from autolabel.utils import get_format_variables
from autolabel.tasks.utils import filter_unlabeled_examples
@@ -60,8 +60,19 @@ def construct_prompt(
)
num_labels = len(labels_list)

is_refuel_llm = self.config.provider() == ModelProvider.REFUEL

if is_refuel_llm:
labels = (
", ".join([f'\\"{i}\\"' for i in labels_list[:-1]])
+ " or "
+ f'\\"{labels_list[-1]}\\"'
)
else:
labels = "\n".join(labels_list)

fmt_task_guidelines = self.task_guidelines.format(
num_labels=num_labels, labels="\n".join(labels_list)
num_labels=num_labels, labels=labels
)

# prepare seed examples
@@ -91,7 +102,7 @@ def construct_prompt(
return self.prompt_template.format(
task_guidelines=fmt_task_guidelines,
output_guidelines=self.output_guidelines,
seed_examples="\n\n".join(fmt_examples),
seed_examples="\n".join(fmt_examples),
current_example=current_example,
)
else:
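Label formatting now diverges by provider; a small illustration of the two branches above, with a hypothetical label list:

```python
labels_list = ["refund_request", "card_declined", "other"]

# Refuel LLM: escaped-quoted labels on one line, with "or" before the last one
refuel_labels = (
    ", ".join([f'\\"{i}\\"' for i in labels_list[:-1]])
    + " or "
    + f'\\"{labels_list[-1]}\\"'
)
print(refuel_labels)  # \"refund_request\", \"card_declined\" or \"other\"

# Other providers keep the original newline-separated list
print("\n".join(labels_list))
```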
12 changes: 11 additions & 1 deletion src/autolabel/tasks/question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@

from autolabel.confidence import ConfidenceCalculator
from autolabel.configs import AutolabelConfig
from autolabel.schema import LLMAnnotation, MetricType, MetricResult, F1Type
from autolabel.schema import (
LLMAnnotation,
MetricType,
MetricResult,
F1Type,
ModelProvider,
)
from autolabel.tasks import BaseTask
from autolabel.tasks.utils import normalize_text
from autolabel.utils import get_format_variables
@@ -32,6 +38,10 @@ class QuestionAnsweringTask(BaseTask):
GENERATE_EXPLANATION_PROMPT = "You are an expert at providing a well reasoned explanation for the output of a given task. \n\nBEGIN TASK DESCRIPTION\n{task_guidelines}\nEND TASK DESCRIPTION\nYou will be given an input example and the corresponding output. You will be given a question and an answer. Your job is to provide an explanation for why the answer is correct for the task above.\nThink step by step and generate an explanation. The last line of the explanation should be - So, the answer is <label>.\n{labeled_example}\nExplanation: "

def __init__(self, config: AutolabelConfig) -> None:
is_refuel_llm = config.provider() == ModelProvider.REFUEL
if is_refuel_llm:
self.DEFAULT_OUTPUT_GUIDELINES = ""

super().__init__(config)
self.metrics = [
AccuracyMetric(),
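The override above works because assigning self.DEFAULT_OUTPUT_GUIDELINES before calling super().__init__ creates an instance attribute that shadows the class default when the base class reads it. A toy sketch of that mechanism (class names and strings are illustrative, not the real Autolabel ones):

```python
class Base:
    DEFAULT_OUTPUT_GUIDELINES = "You will answer with just the correct output ..."

    def __init__(self):
        self.output_guidelines = self.DEFAULT_OUTPUT_GUIDELINES


class QATask(Base):
    def __init__(self, is_refuel_llm: bool):
        if is_refuel_llm:
            # instance attribute shadows the class-level default read in Base.__init__
            self.DEFAULT_OUTPUT_GUIDELINES = ""
        super().__init__()


print(repr(QATask(is_refuel_llm=True).output_guidelines))   # ''
print(repr(QATask(is_refuel_llm=False).output_guidelines))  # the default guidelines
```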
7 changes: 4 additions & 3 deletions tests/unit/llm_test.py
@@ -1,3 +1,4 @@
import json
from autolabel.configs import AutolabelConfig
from autolabel.models.anthropic import AnthropicLLM
from autolabel.models.openai import OpenAILLM
@@ -211,7 +212,7 @@ def __init__(self, resp):
self.resp = resp

def json(self):
return {"body": self.resp}
return self.resp

def raise_for_status(self):
pass
@@ -222,7 +223,7 @@ def raise_for_status(self):
prompts = ["test1", "test2"]
mocker.patch(
"requests.post",
return_value=PostRequestMockResponse(resp='"Answers"'),
return_value=PostRequestMockResponse(resp='{"generated_text": "Answers"}'),
)
x = model.label(prompts)
assert [i[0].text for i in x.generations] == ["Answers", "Answers"]
@@ -242,7 +243,7 @@ def test_refuel_return_probs():
model = RefuelLLM(
config=AutolabelConfig(config="tests/assets/banking/config_banking_refuel.json")
)
assert model.returns_token_probs() is False
assert model.returns_token_probs() is True


################### REFUEL TESTS #######################