Skip to content

Commit

Permalink
enforce strings in response outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanbouchard committed Dec 12, 2024
1 parent 9513440 commit bbe6c38
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ We can use `ResponseGenerator.generate_responses` to generate 25 responses for e
from langfair.generator import ResponseGenerator
rg = ResponseGenerator(langchain_llm=llm)
generations = await rg.generate_responses(prompts=prompts, count=25)
responses = [str(r) for r in generations["data"]["response"]]
duplicated_prompts = [str(r) for r in generations["data"]["prompt"]] # so prompts correspond to responses
responses = generations["data"]["response"]
duplicated_prompts = generations["data"]["prompt"] # so prompts correspond to responses
```

##### Compute toxicity metrics
Expand Down Expand Up @@ -96,8 +96,8 @@ cg = CounterfactualGenerator(langchain_llm=llm)
cf_generations = await cg.generate_responses(
prompts=prompts, attribute='gender', count=25
)
male_responses = [str(r) for r in cf_generations['data']['male_response']]
female_responses = [str(r) for r in cf_generations['data']['female_response']]
male_responses = cf_generations['data']['male_response']
female_responses = cf_generations['data']['female_response']
```

Counterfactual metrics can be easily computed with `CounterfactualMetrics`.
Expand Down
3 changes: 2 additions & 1 deletion langfair/generator/counterfactual.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,8 @@ async def generate_responses(
tasks,
duplicated_prompts_dict[prompt_key],
) = self._create_tasks(chain=chain, prompts=prompts_dict[prompt_key])
responses_dict[group + "_response"] = await asyncio.gather(*tasks)
tmp_responses = await asyncio.gather(*tasks)
responses_dict[group + "_response"] = [str(r) for r in tmp_responses]
# stop = time.time()

non_completion_rate = len(
Expand Down
2 changes: 1 addition & 1 deletion langfair/generator/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ async def generate_responses(
return {
"data": {
"prompt": duplicated_prompts,
"response": responses,
"response": [str(r) for r in responses],
},
"metadata": {
"non_completion_rate": non_completion_rate,
Expand Down

0 comments on commit bbe6c38

Please sign in to comment.