Skip to content

Commit

Permalink
refactor and include more metadata in check_ftu
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanbouchard committed Jan 1, 2025
1 parent f14c543 commit 0f5072d
Showing 1 changed file with 39 additions and 18 deletions.
57 changes: 39 additions & 18 deletions langfair/generator/counterfactual.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,7 @@ def parse_texts(
List of length `len(texts)` with each element being a list of identified protected
attribute words in provided text
"""
assert not (custom_list and attribute), """
Either custom_list or attribute must be None.
"""
assert custom_list or attribute in ["race", "gender"], """
If custom_list is None, attribute must be 'race' or 'gender'.
"""
self._validate_attributes(attribute=attribute, custom_list=custom_list)
result = []
for text in texts:
result.append(
Expand Down Expand Up @@ -234,13 +229,9 @@ def create_prompts(
dict
Dictionary containing counterfactual prompts
"""
assert not (custom_dict and attribute), """
Either custom_dict or attribute must be None.
"""
assert custom_dict or attribute in [
"gender",
"race",
], "If custom_dict is None, attribute must be 'gender' or 'race'."
self._validate_attributes(
attribute=attribute, custom_dict=custom_dict, for_parsing=False
)

custom_list = (
list(itertools.chain(*custom_dict.values())) if custom_dict else None
Expand Down Expand Up @@ -455,6 +446,7 @@ def check_ftu(
'filtered_prompt_count' : int
The number of prompts that satisfy FTU.
"""
self._validate_attributes(attribute=attribute, custom_list=custom_list)
attribute_to_print = (
"Protected attribute" if not attribute else attribute.capitalize()
)
Expand Down Expand Up @@ -482,7 +474,10 @@ def check_ftu(
},
"metadata": {
"ftu_satisfied": ftu_satisfied,
"n_prompts_with_attribute_words": n_prompts_with_attribute_words
"n_prompts_with_attribute_words": n_prompts_with_attribute_words,
"attribute": attribute,
"custom_list": custom_list,
"subset_prompts": subset_prompts
}
}

Expand All @@ -492,7 +487,10 @@ def _subset_prompts(
attribute: Optional[str] = None,
custom_list: Optional[List[str]] = None,
) -> Tuple[List[str], List[List[str]]]:
"""Subset prompts that contain protected attribute words"""
"""
Helper function to subset prompts that contain protected attribute words and also
return the full set of parsing results
"""
attribute_to_print = (
"Protected attribute" if not attribute else attribute.capitalize()
)
Expand Down Expand Up @@ -572,9 +570,6 @@ def _sub_from_dict(

return output_dict

################################################################################
# Class for protected attribute scanning and replacing protected attribute words
################################################################################
@staticmethod
def _get_race_subsequences(text: str) -> List[str]:
"""Used to check for string sequences"""
Expand All @@ -596,3 +591,29 @@ def _replace_race(text: str, target_race: str) -> str:
for subseq in STRICT_RACE_WORDS:
seq = seq.replace(subseq, race_replacement_mapping[subseq])
return seq

@staticmethod
def _validate_attributes(
    attribute: Optional[str] = None,
    custom_list: Optional[List[str]] = None,
    custom_dict: Optional[Dict[str, List[str]]] = None,
    for_parsing: bool = True,
) -> None:
    """
    Validate the protected attribute arguments supplied to a calling method.

    Exactly one source of protected attribute words may be provided: either a
    built-in `attribute` or a custom collection. Which custom collection is
    checked depends on `for_parsing`.

    Parameters
    ----------
    attribute : {'race', 'gender'}, default=None
        Name of a built-in protected attribute. Must be None when the relevant
        custom collection is provided.

    custom_list : list of str, default=None
        Custom protected attribute words. Only inspected when `for_parsing`
        is True.

    custom_dict : dict, default=None
        Custom mapping whose values are collections of protected attribute
        words. Only inspected when `for_parsing` is False.

    for_parsing : bool, default=True
        If True, validate `attribute` against `custom_list`; otherwise validate
        `attribute` against `custom_dict`.

    Raises
    ------
    ValueError
        If both `attribute` and the relevant custom collection are provided, or
        if the custom collection is missing/empty and `attribute` is not
        'race' or 'gender'.
    """
    # Both modes apply the same two rules; only the custom argument differs.
    if for_parsing:
        custom_name, custom_value = "custom_list", custom_list
    else:
        custom_name, custom_value = "custom_dict", custom_dict

    if custom_value and attribute:
        raise ValueError(f"Either {custom_name} or attribute must be None.")

    # An empty custom collection is treated the same as None (truthiness check).
    if not custom_value and attribute not in ("race", "gender"):
        raise ValueError(
            f"If {custom_name} is None, attribute must be 'race' or 'gender'."
        )

0 comments on commit 0f5072d

Please sign in to comment.