Skip to content

Commit

Permalink
new method for ftu check
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanbouchard committed Dec 31, 2024
1 parent 7cf6cef commit a78cd6d
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -456,13 +456,13 @@
],
"source": [
"# Check for fairness through unawareness\n",
"attribute = 'race'\n",
"df = pd.DataFrame({'prompt': prompts})\n",
"df[attribute + '_words'] = cdg.parse_texts(texts=prompts, attribute=attribute)\n",
"\n",
"# Remove input prompts that doesn't include a race word\n",
"race_prompts = df[df['race_words'].apply(lambda x: len(x) > 0)][['prompt','race_words']]\n",
"print(f\"Race words found in {len(race_prompts)} prompts\")\n",
"ftu_result = cdg.check_ftu(\n",
" prompts=prompts,\n",
" attribute='race',\n",
" subset_prompts=True\n",
")\n",
"race_prompts = pd.DataFrame(ftu_result[\"data\"])\n",
"race_prompts.tail(5)"
]
},
Expand Down
71 changes: 71 additions & 0 deletions langfair/generator/counterfactual.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,77 @@ async def generate_responses(
},
}

def check_ftu(
    self,
    prompts: List[str],
    attribute: Optional[str] = None,
    custom_list: Optional[List[str]] = None,
    subset_prompts: bool = True,
) -> Dict[str, Any]:
    """
    Checks for fairness through unawareness (FTU) based on a list of prompts and a specified
    protected attribute. FTU is reported as satisfied only when no prompt contains a
    protected attribute word.

    Parameters
    ----------
    prompts : list of strings
        A list of prompts to be parsed for protected attribute words

    attribute : {'race','gender'}, default=None
        Specifies what to parse for among race words and gender words. Must be specified
        if custom_list is None

    custom_list : List[str], default=None
        Custom list of tokens to use for parsing prompts. Must be provided if attribute is None.

    subset_prompts : bool, default=True
        If True, the returned data only includes prompts that contain at least one
        attribute word, together with the words found in each of them. If False, every
        prompt is returned alongside its (possibly empty) parsing result.

    Returns
    -------
    dict
        A dictionary with two keys: 'data' and 'metadata'.

        'data' : dict
            A dictionary containing the prompts and the attribute words found in them.
            Both lists always have the same length and are index-aligned.

            'prompts' : list
                A list of prompts (filtered when `subset_prompts` is True).
            'attribute_words' : list
                The attribute words found in each returned prompt.

        'metadata' : dict
            A dictionary containing metadata related to FTU.

            'ftu_satisfied' : boolean
                Boolean indicator of whether or not prompts satisfy FTU, i.e. True when
                no prompt contains an attribute word.
            'n_prompts_with_attribute_words' : int
                The number of prompts that contain at least one attribute word.
    """
    attribute_to_print = attribute.capitalize() if attribute else "Protected attribute"

    # The parsing result is paired with `prompts` position by position below, so the
    # unfiltered, one-entry-per-prompt output is required; `subset_prompts` is therefore
    # applied here rather than forwarded to `parse_texts`.
    attribute_words = self.parse_texts(
        texts=prompts,
        attribute=attribute,
        custom_list=custom_list,
    )

    # Keep prompts and their words paired so the two returned lists stay aligned.
    matched = [
        (prompt, words) for prompt, words in zip(prompts, attribute_words) if words
    ]
    prompts_subset = [prompt for prompt, _ in matched]
    attribute_words_subset = [words for _, words in matched]

    n_prompts_with_attribute_words = len(matched)
    # FTU holds only when the prompts make no mention of the protected attribute.
    ftu_satisfied = n_prompts_with_attribute_words == 0

    ftu_print = "FTU is satisfied." if ftu_satisfied else "FTU is not satisfied."
    print(
        f"{attribute_to_print} words found in {n_prompts_with_attribute_words} prompts. {ftu_print}"
    )

    return {
        "data": {
            "prompts": prompts_subset if subset_prompts else prompts,
            "attribute_words": (
                attribute_words_subset if subset_prompts else attribute_words
            ),
        },
        "metadata": {
            "ftu_satisfied": ftu_satisfied,
            "n_prompts_with_attribute_words": n_prompts_with_attribute_words,
        },
    }

def _subset_prompts(
self,
prompts: List[str],
Expand Down

0 comments on commit a78cd6d

Please sign in to comment.