From a78cd6dd933efe92e06da0eedc276137fa30dfcd Mon Sep 17 00:00:00 2001
From: Dylan Bouchard
Date: Tue, 31 Dec 2024 15:50:32 -0500
Subject: [PATCH] new method for ftu check

---
 .../counterfactual_metrics_demo.ipynb         | 12 ++--
 langfair/generator/counterfactual.py          | 77 ++++++++++++++++++++
 2 files changed, 83 insertions(+), 6 deletions(-)

diff --git a/examples/evaluations/text_generation/counterfactual_metrics_demo.ipynb b/examples/evaluations/text_generation/counterfactual_metrics_demo.ipynb
index 3f34302..60ccf0a 100644
--- a/examples/evaluations/text_generation/counterfactual_metrics_demo.ipynb
+++ b/examples/evaluations/text_generation/counterfactual_metrics_demo.ipynb
@@ -456,13 +456,13 @@
    ],
    "source": [
     "# Check for fairness through unawareness\n",
-    "attribute = 'race'\n",
     "df = pd.DataFrame({'prompt': prompts})\n",
-    "df[attribute + '_words'] = cdg.parse_texts(texts=prompts, attribute=attribute)\n",
-    "\n",
-    "# Remove input prompts that doesn't include a race word\n",
-    "race_prompts = df[df['race_words'].apply(lambda x: len(x) > 0)][['prompt','race_words']]\n",
-    "print(f\"Race words found in {len(race_prompts)} prompts\")\n",
+    "ftu_result = cdg.check_ftu(\n",
+    "    prompts=prompts,\n",
+    "    attribute='race',\n",
+    "    subset_prompts=True\n",
+    ")\n",
+    "race_prompts = pd.DataFrame(ftu_result[\"data\"])\n",
     "race_prompts.tail(5)"
   ]
  },
diff --git a/langfair/generator/counterfactual.py b/langfair/generator/counterfactual.py
index 33ab444..79f6d64 100644
--- a/langfair/generator/counterfactual.py
+++ b/langfair/generator/counterfactual.py
@@ -412,6 +412,83 @@ async def generate_responses(
             },
         }
 
+    def check_ftu(
+        self,
+        prompts: List[str],
+        attribute: Optional[str] = None,
+        custom_list: Optional[List[str]] = None,
+        subset_prompts: bool = True,
+    ) -> Dict[str, Any]:
+        """
+        Checks for fairness through unawareness (FTU) based on a list of prompts
+        and a specified protected attribute.
+
+        Parameters
+        ----------
+        prompts : list of strings
+            A list of prompts to be parsed for protected attribute words.
+
+        attribute : {'race','gender'}, default=None
+            Specifies what to parse for among race words and gender words. Must be
+            specified if custom_list is None.
+
+        custom_list : List[str], default=None
+            Custom list of tokens to use for parsing prompts. Must be provided if
+            attribute is None.
+
+        subset_prompts : bool, default=True
+            If True, 'data' contains only the subset of prompts that contain
+            attribute words.
+
+        Returns
+        -------
+        dict
+            A dictionary with two keys: 'data' and 'metadata'.
+            'data' : dict
+                A dictionary containing the prompts and the attribute words found in them.
+                'prompts' : list
+                    A list of prompts.
+                'attribute_words' : list
+                    A list of the attribute words found in each prompt.
+            'metadata' : dict
+                A dictionary containing metadata related to FTU.
+                'ftu_satisfied' : boolean
+                    Boolean indicator of whether or not prompts satisfy FTU.
+                'n_prompts_with_attribute_words' : int
+                    The number of prompts that contain attribute words.
+        """
+        attribute_to_print = (
+            "Protected attribute" if not attribute else attribute.capitalize()
+        )
+        attribute_words = self.parse_texts(
+            texts=prompts, attribute=attribute, custom_list=custom_list
+        )
+        prompts_subset = [
+            prompt for i, prompt in enumerate(prompts) if attribute_words[i]
+        ]
+        attribute_words_subset = [words for words in attribute_words if words]
+        n_prompts_with_attribute_words = len(prompts_subset)
+        ftu_satisfied = n_prompts_with_attribute_words == 0
+
+        ftu_print = f"FTU is {'' if ftu_satisfied else 'not '}satisfied."
+        print(
+            f"{attribute_to_print} words found in "
+            f"{n_prompts_with_attribute_words} prompts. {ftu_print}"
+        )
+
+        return {
+            "data": {
+                "prompts": prompts_subset if subset_prompts else prompts,
+                "attribute_words": attribute_words_subset
+                if subset_prompts
+                else attribute_words,
+            },
+            "metadata": {
+                "ftu_satisfied": ftu_satisfied,
+                "n_prompts_with_attribute_words": n_prompts_with_attribute_words,
+            },
+        }
+
     def _subset_prompts(
         self,
         prompts: List[str],