From 0f5072d107d01c38eb99edef0df935dd3c6b1682 Mon Sep 17 00:00:00 2001 From: Dylan Bouchard Date: Wed, 1 Jan 2025 09:44:57 -0500 Subject: [PATCH] refactor and include more metadata in check_ftu --- langfair/generator/counterfactual.py | 57 +++++++++++++++++++--------- 1 file changed, 39 insertions(+), 18 deletions(-) diff --git a/langfair/generator/counterfactual.py b/langfair/generator/counterfactual.py index 1fb54352..2903a010 100644 --- a/langfair/generator/counterfactual.py +++ b/langfair/generator/counterfactual.py @@ -191,12 +191,7 @@ def parse_texts( List of length `len(texts)` with each element being a list of identified protected attribute words in provided text """ - assert not (custom_list and attribute), """ - Either custom_list or attribute must be None. - """ - assert custom_list or attribute in ["race", "gender"], """ - If custom_list is None, attribute must be 'race' or 'gender'. - """ + self._validate_attributes(attribute=attribute, custom_list=custom_list) result = [] for text in texts: result.append( @@ -234,13 +229,9 @@ def create_prompts( dict Dictionary containing counterfactual prompts """ - assert not (custom_dict and attribute), """ - Either custom_dict or attribute must be None. - """ - assert custom_dict or attribute in [ - "gender", - "race", - ], "If custom_dict is None, attribute must be 'gender' or 'race'." + self._validate_attributes( + attribute=attribute, custom_dict=custom_dict, for_parsing=False + ) custom_list = ( list(itertools.chain(*custom_dict.values())) if custom_dict else None @@ -455,6 +446,7 @@ def check_ftu( 'filtered_prompt_count' : int The number of prompts that satisfy FTU. """ + self._validate_attributes(attribute=attribute, custom_list=custom_list) attribute_to_print = ( "Protected attribute" if not attribute else attribute.capitalize() ) @@ -482,7 +474,10 @@ def check_ftu( }, "metadata": { "ftu_satisfied": ftu_satisfied, - "n_prompts_with_attribute_words": n_prompts_with_attribute_words + "n_prompts_with_attribute_words": n_prompts_with_attribute_words, + "attribute": attribute, + "custom_list": custom_list, + "subset_prompts": subset_prompts } } @@ -492,7 +487,10 @@ def _subset_prompts( attribute: Optional[str] = None, custom_list: Optional[List[str]] = None, ) -> Tuple[List[str], List[List[str]]]: - """Subset prompts that contain protected attribute words""" + """ + Helper function to subset prompts that contain protected attribute words and also + return the full set of parsing results + """ attribute_to_print = ( "Protected attribute" if not attribute else attribute.capitalize() ) @@ -572,9 +570,6 @@ def _sub_from_dict( return output_dict - ################################################################################ - # Class for protected attribute scanning and replacing protected attribute words - ################################################################################ @staticmethod def _get_race_subsequences(text: str) -> List[str]: """Used to check for string sequences""" @@ -596,3 +591,29 @@ def _replace_race(text: str, target_race: str) -> str: for subseq in STRICT_RACE_WORDS: seq = seq.replace(subseq, race_replacement_mapping[subseq]) return seq + + @staticmethod + def _validate_attributes( + attribute: Optional[str] = None, + custom_list: Optional[List[str]] = None, + custom_dict: Optional[Dict[str, str]] = None, + for_parsing: bool = True + ) -> None: + if for_parsing: + if (custom_list and attribute): + raise ValueError( + "Either custom_list or attribute must be None." + ) + if not (custom_list or attribute in ["race", "gender"]): + raise ValueError( + "If custom_list is None, attribute must be 'race' or 'gender'." + ) + else: + if (custom_dict and attribute): + raise ValueError( + "Either custom_dict or attribute must be None." + ) + if not (custom_dict or attribute in ["race", "gender"]): + raise ValueError( + "If custom_dict is None, attribute must be 'race' or 'gender'." + ) \ No newline at end of file