
Commit 7c84aa7

Merge pull request #85 from cvs-health/db/ftu_check_method
New FTU check method
2 parents 7cf6cef + 0f5072d commit 7c84aa7

2 files changed: +192 −86 lines changed

examples/evaluations/text_generation/counterfactual_metrics_demo.ipynb (+80 −69)
@@ -20,7 +20,9 @@
 "source": [
 "Content\n",
 "1. [Introduction](#section1')\n",
-"2. [Generate Demo Dataset](#section2')\n",
+"2. [Generate Counterfactual Dataset](#section2')<br>\n",
+" 2.1 [Check fairness through unawareness](#section2-1')<br>\n",
+" 2.2 [Generate counterfactual responses](#section2-2')\n",
 "3. [Assessment](#section3')<br>\n",
 " 3.1 [Lazy Implementation](#section3-1')<br>\n",
 " 3.2 [Separate Implementation](#section3-2')\n",
@@ -120,7 +122,7 @@
 "metadata": {},
 "source": [
 "<a id='section2'></a>\n",
-"## 2. Generate Demo Dataset"
+"## 2. Generate Counterfactual Dataset"
 ]
 },
 {
@@ -163,64 +165,15 @@
 "tags": []
 },
 "source": [
-"### Counterfactual Dataset Generator\n",
+"#### Counterfactual Dataset Generator\n",
 "***\n",
-"##### `CounterfactualGenerator()` - Class for generating data for counterfactual discrimination assessment (class)\n",
+"##### `CounterfactualGenerator()` - Used for generating data for counterfactual fairness assessment (class)\n",
 "\n",
 "**Class Attributes:**\n",
 "\n",
-"- `langchain_llm` (**langchain llm (Runnable), default=None**) A langchain llm object to get passed to LLMChain `llm` argument. \n",
+"- `langchain_llm` (**langchain llm (Runnable), default=None**) A LangChain llm object to get passed to LangChain `RunnableSequence`. \n",
 "- `suppressed_exceptions` (**tuple, default=None**) Specifies which exceptions to handle as 'Unable to get response' rather than raising the exception\n",
-"- `max_calls_per_min` (**deprecated as of 0.2.0**) Use LangChain's InMemoryRateLimiter instead.\n",
-"\n",
-"**Methods:**\n",
-"\n",
-"1. `parse_texts()` - Parses a list of texts for protected attribute words and names\n",
-"\n",
-" **Method Parameters:**\n",
-"\n",
-" - `text` - (**string**) A text corpus to be parsed for protected attribute words and names\n",
-" - `attribute` - (**{'race','gender','name'}**) Specifies what to parse for among race words, gender words, and names\n",
-" - `custom_list` - (**List[str], default=None**) Custom list of tokens to use for parsing prompts. Must be provided if attribute is None.\n",
-" \n",
-" **Returns:**\n",
-" - list of results containing protected attribute words found (**list**)\n",
-"\n",
-"2. `create_prompts()` - Creates counterfactual prompts by counterfactual substitution\n",
-"\n",
-" **Method Parameters:**\n",
-"\n",
-" - `prompts` - (**List of strings**) A list of prompts on which counterfactual substitution and response generation will be done\n",
-" - `attribute` - (**{'gender', 'race'}, default=None**) Specifies what to parse for among race words and gender words. Must be specified if custom_list is None.\n",
-" - `custom_dict` - (**Dict[str, List[str]], default=None**) A dictionary containing corresponding lists of tokens for counterfactual substitution. Keys should correspond to groups. Must be provided if attribute is None. For example: {'male': ['he', 'him', 'woman'], 'female': ['she', 'her', 'man']}\n",
-" subset_prompts : bool, default=True\n",
-" \n",
-" **Returns:**\n",
-" - list of prompts on which counterfactual substitution was completed (**list**)\n",
-" \n",
-"3. `neutralize_tokens()` - Neutralize gender and race words contained in a list of texts. Replaces gender words with a gender-neutral equivalent and race words with \"[MASK]\".\n",
-"\n",
-" **Method Parameters:**\n",
-"\n",
-" - `text_list` - (**List of strings**) A list of texts on which gender or race neutralization will occur\n",
-" - `attribute` - (**{'gender', 'race'}, default='gender'**) Specifies whether to use race or gender for for neutralization\n",
-"\n",
-" **Returns:**\n",
-" - list of texts neutralized with respect to race or gender (**list**)\n",
-"\n",
-"4. `generate_responses()` - Creates counterfactual prompts obtained by counterfactual substitution and generates responses asynchronously. \n",
-"\n",
-" **Method Parameters:**\n",
-"\n",
-" - `prompts` - (**List of strings**) A list of prompts on which counterfactual substitution and response generation will be done\n",
-" - `attribute` - (**{'gender', 'race'}, default='gender'**) Specifies whether to use race or gender for counterfactual substitution\n",
-" - `system_prompt` - (**str, default=\"You are a helpful assistant.\"**) Specifies system prompt for generation \n",
-" - `count` - (**int, default=25**) Specifies number of responses to generate for each prompt.\n",
-" - `custom_dict` - (**Dict[str, List[str]], default=None**) A dictionary containing corresponding lists of tokens for counterfactual substitution. Keys should correspond to groups. Must be provided if attribute is None. For example: {'male': ['he', 'him', 'woman'], 'female': ['she', 'her', 'man']}\n",
-"\n",
-" **Returns:** A dictionary with two keys: `data` and `metadata`.\n",
-" - `data` (**dict**) A dictionary containing the prompts and responses.\n",
-" - `metadata` (**dict**) A dictionary containing metadata about the generation process, including non-completion rate, temperature, count, original prompts, and identified proctected attribute words."
+"- `max_calls_per_min` (**deprecated as of 0.2.0**) Use LangChain's InMemoryRateLimiter instead."
 ]
 },
 {
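For context, a minimal sketch of how the class described above might be instantiated now that `max_calls_per_min` is deprecated in favor of LangChain's InMemoryRateLimiter. This is not part of the commit; the model name, provider package, and the `langfair.generator` import path are assumptions.

```python
# Sketch only, not part of this commit: wiring the documented class attributes
# to a rate-limited LangChain chat model (model name and import paths are assumptions).
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_openai import ChatOpenAI
from langfair.generator import CounterfactualGenerator

rate_limiter = InMemoryRateLimiter(
    requests_per_second=5, check_every_n_seconds=0.5, max_bucket_size=500
)
llm = ChatOpenAI(model="gpt-4o-mini", temperature=1, rate_limiter=rate_limiter)

cdg = CounterfactualGenerator(
    langchain_llm=llm,
    suppressed_exceptions=None,  # optionally a tuple of exception types to report as 'Unable to get response'
)
```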
@@ -366,7 +319,32 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"For illustration, this notebook assesses with 'race' as the protected attribute, but metrics can be evaluated for 'gender' or other custom protected attributes in the same way. First, the above mentioned `parse_texts` method is used to identify the input prompts that contain protected attribute words. \n",
+"<a id='section2-1'></a>\n",
+"### 2.1 Check fairness through unawareness"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"#### `CounterfactualGenerator.check_ftu()` - Parses prompts to check for fairness through unawareness. Returns dictionary with prompts, corresponding attribute words found, and applicable metadata. \n",
+"\n",
+"**Method Parameters:**\n",
+"\n",
+"- `text` - (**string**) A text corpus to be parsed for protected attribute words and names\n",
+"- `attribute` - (**{'race','gender','name'}**) Specifies what to parse for among race words, gender words, and names\n",
+"- `custom_list` - (**List[str], default=None**) Custom list of tokens to use for parsing prompts. Must be provided if attribute is None.\n",
+"- `subset_prompts` - (**bool, default=True**) Indicates whether to return all prompts or only those containing attribute words\n",
+"\n",
+"**Returns:**\n",
+"- dictionary with prompts, corresponding attribute words found, and applicable metadata (**dict**)"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"For illustration, this notebook assesses with 'race' as the protected attribute, but metrics can be evaluated for 'gender' or other custom protected attributes in the same way. First, the above mentioned `check_ftu` method is used to check for fairness through unawareness, i.e. whether prompts contain mentions of protected attribute words. In the returned object, prompts are subset to retain only those that contain protected attribute words. \n",
 "\n",
 "Note: We recommend using at least 1000 prompts that contain protected attribute words for better estimates. Otherwise, increase the `count` attribute of the `CounterfactualGenerator` class to generate more responses."
 ]
@@ -456,21 +434,54 @@
 ],
 "source": [
 "# Check for fairness through unawareness\n",
-"attribute = 'race'\n",
-"df = pd.DataFrame({'prompt': prompts})\n",
-"df[attribute + '_words'] = cdg.parse_texts(texts=prompts, attribute=attribute)\n",
-"\n",
-"# Remove input prompts that doesn't include a race word\n",
-"race_prompts = df[df['race_words'].apply(lambda x: len(x) > 0)][['prompt','race_words']]\n",
-"print(f\"Race words found in {len(race_prompts)} prompts\")\n",
+"ftu_result = cdg.check_ftu(\n",
+"    prompts=prompts,\n",
+"    attribute='race',\n",
+"    subset_prompts=True\n",
+")\n",
+"race_prompts = pd.DataFrame(ftu_result[\"data\"]).rename(columns={'attribute_words': 'race_words'})\n",
 "race_prompts.tail(5)"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"Generate the model response on the input prompts using `generate_responses` method."
+"As seen above, this use case does not satisfy fairness through unawareness, since 246 prompts contain mentions of race words."
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"<a id='section2-2'></a>\n",
+"### 2.2 Generate counterfactual responses"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"#### `CounterfactualGenerator.generate_responses()` - Creates counterfactual prompts obtained by counterfactual substitution and generates responses asynchronously. \n",
+"\n",
+"**Method Parameters:**\n",
+"\n",
+"- `prompts` - (**List of strings**) A list of prompts on which counterfactual substitution and response generation will be done\n",
+"- `attribute` - (**{'gender', 'race'}, default='gender'**) Specifies whether to use race or gender for counterfactual substitution\n",
+"- `system_prompt` - (**str, default=\"You are a helpful assistant.\"**) Specifies system prompt for generation \n",
+"- `count` - (**int, default=25**) Specifies number of responses to generate for each prompt.\n",
+"- `custom_dict` - (**Dict[str, List[str]], default=None**) A dictionary containing corresponding lists of tokens for counterfactual substitution. Keys should correspond to groups. Must be provided if attribute is None. For example: {'male': ['he', 'him', 'woman'], 'female': ['she', 'her', 'man']}\n",
+"\n",
+"**Returns:** A dictionary with two keys: `data` and `metadata`.\n",
+"- `data` (**dict**) A dictionary containing the prompts and responses.\n",
+"- `metadata` (**dict**) A dictionary containing metadata about the generation process, including non-completion rate, temperature, count, original prompts, and identified protected attribute words."
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Create counterfactual input prompts and generate corresponding LLM responses using `generate_responses` method."
 ]
 },
 {
@@ -566,7 +577,7 @@
 ],
 "source": [
 "generations = await cdg.generate_responses(\n",
-"    prompts=df['prompt'], attribute='race', count=1\n",
+"    prompts=race_prompts['prompt'], attribute='race', count=1\n",
 ")\n",
 "output_df = pd.DataFrame(generations['data'])\n",
 "output_df.head(1)"
@@ -617,7 +628,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"### `CounterfactualMetrics()` - Calculate all the counterfactual metrics (class)\n",
+"#### `CounterfactualMetrics()` - Calculate all the counterfactual metrics (class)\n",
 "**Class Attributes:**\n",
 "- `metrics` - (**List of strings/Metric objects**) Specifies which metrics to use.\n",
 "Default option is a list of strings (`metrics` = [\"Cosine\", \"Rougel\", \"Bleu\", \"Sentiment Bias\"]).\n",
@@ -1206,9 +1217,9 @@
 "uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/workbench-notebooks:m125"
 },
 "kernelspec": {
-"display_name": "langchain",
+"display_name": ".venv",
 "language": "python",
-"name": "langchain"
+"name": "python3"
 },
 "language_info": {
 "codemirror_mode": {
@@ -1220,7 +1231,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.11.10"
+"version": "3.9.6"
 }
 },
 "nbformat": 4,

langfair/generator/counterfactual.py (+112 −17)
@@ -191,12 +191,7 @@ def parse_texts(
             List of length `len(texts)` with each element being a list of identified protected
             attribute words in provided text
         """
-        assert not (custom_list and attribute), """
-        Either custom_list or attribute must be None.
-        """
-        assert custom_list or attribute in ["race", "gender"], """
-        If custom_list is None, attribute must be 'race' or 'gender'.
-        """
+        self._validate_attributes(attribute=attribute, custom_list=custom_list)
         result = []
         for text in texts:
             result.append(
@@ -234,13 +229,9 @@ def create_prompts(
         dict
             Dictionary containing counterfactual prompts
         """
-        assert not (custom_dict and attribute), """
-        Either custom_dict or attribute must be None.
-        """
-        assert custom_dict or attribute in [
-            "gender",
-            "race",
-        ], "If custom_dict is None, attribute must be 'gender' or 'race'."
+        self._validate_attributes(
+            attribute=attribute, custom_dict=custom_dict, for_parsing=False
+        )
 
         custom_list = (
             list(itertools.chain(*custom_dict.values())) if custom_dict else None
@@ -412,13 +403,94 @@ async def generate_responses(
             },
         }
 
+    def check_ftu(
+        self,
+        prompts: List[str],
+        attribute: Optional[str] = None,
+        custom_list: Optional[List[str]] = None,
+        subset_prompts: bool = True,
+    ) -> Dict[str, Any]:
+        """
+        Checks for fairness through unawareness (FTU) based on a list of prompts and a specified protected
+        attribute
+
+        Parameters
+        ----------
+        prompts : list of strings
+            A list of prompts to be parsed for protected attribute words
+
+        attribute : {'race','gender'}, default=None
+            Specifies what to parse for among race words and gender words. Must be specified
+            if custom_list is None
+
+        custom_list : List[str], default=None
+            Custom list of tokens to use for parsing prompts. Must be provided if attribute is None.
+
+        subset_prompts : bool, default=True
+            Indicates whether to return all prompts or only those containing attribute words
+
+        Returns
+        -------
+        dict
+            A dictionary with two keys: 'data' and 'metadata'.
+            'data' : dict
+                A dictionary containing the prompts and responses.
+                'prompt' : list
+                    A list of prompts.
+                'attribute_words' : list
+                    A list of attribute_words in each prompt.
+            'metadata' : dict
+                A dictionary containing metadata related to FTU.
+                'ftu_satisfied' : boolean
+                    Boolean indicator of whether or not prompts satisfy FTU
+                'n_prompts_with_attribute_words' : int
+                    The number of prompts containing attribute words.
+        """
+        self._validate_attributes(attribute=attribute, custom_list=custom_list)
+        attribute_to_print = (
+            "Protected attribute" if not attribute else attribute.capitalize()
+        )
+        attribute_words = self.parse_texts(
+            texts=prompts, attribute=attribute, custom_list=custom_list,
+        )
+        prompts_subset = [
+            prompt for i, prompt in enumerate(prompts) if attribute_words[i]
+        ]
+        attribute_words_subset = [
+            aw for i, aw in enumerate(attribute_words) if attribute_words[i]
+        ]
+
+        n_prompts_with_attribute_words = len(prompts_subset)
+        ftu_satisfied = (n_prompts_with_attribute_words == 0)
+        ftu_text = " " if ftu_satisfied else " not "
+
+        ftu_print = (f"FTU is{ftu_text}satisfied.")
+        print(f"{attribute_to_print} words found in {len(prompts_subset)} prompts. {ftu_print}")
+
+        return {
+            "data": {
"prompts": prompts_subset if subset_prompts else prompts,
+                "attribute_words": attribute_words_subset if subset_prompts else attribute_words
+            },
+            "metadata": {
+                "ftu_satisfied": ftu_satisfied,
+                "n_prompts_with_attribute_words": n_prompts_with_attribute_words,
+                "attribute": attribute,
+                "custom_list": custom_list,
+                "subset_prompts": subset_prompts
+            }
+        }
+
     def _subset_prompts(
         self,
         prompts: List[str],
         attribute: Optional[str] = None,
         custom_list: Optional[List[str]] = None,
     ) -> Tuple[List[str], List[List[str]]]:
-        """Subset prompts that contain protected attribute words"""
+        """
+        Helper function to subset prompts that contain protected attribute words and also
+        return the full set of parsing results
+        """
         attribute_to_print = (
             "Protected attribute" if not attribute else attribute.capitalize()
         )
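For context, a hedged sketch of consuming the dictionary returned by the new `check_ftu` method, with keys taken from its return statement above; `cdg` and `prompts` are assumed to be defined as in the demo notebook.

```python
# Sketch only, not part of this commit: inspecting check_ftu output.
result = cdg.check_ftu(prompts=prompts, attribute="race", subset_prompts=True)

metadata = result["metadata"]
if not metadata["ftu_satisfied"]:
    # Attribute words were found, so a counterfactual assessment is warranted.
    print(metadata["n_prompts_with_attribute_words"], "prompts contain race words")
    flagged_prompts = result["data"]["prompt"]          # prompts retained after subsetting
    flagged_words = result["data"]["attribute_words"]   # race words found in each prompt
```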
@@ -498,9 +570,6 @@ def _sub_from_dict(
 
         return output_dict
 
-    ################################################################################
-    # Class for protected attribute scanning and replacing protected attribute words
-    ################################################################################
     @staticmethod
     def _get_race_subsequences(text: str) -> List[str]:
         """Used to check for string sequences"""
@@ -522,3 +591,29 @@ def _replace_race(text: str, target_race: str) -> str:
         for subseq in STRICT_RACE_WORDS:
             seq = seq.replace(subseq, race_replacement_mapping[subseq])
         return seq
+
+    @staticmethod
+    def _validate_attributes(
+        attribute: Optional[str] = None,
+        custom_list: Optional[List[str]] = None,
+        custom_dict: Optional[Dict[str, str]] = None,
+        for_parsing: bool = True
+    ) -> None:
+        if for_parsing:
+            if (custom_list and attribute):
+                raise ValueError(
+                    "Either custom_list or attribute must be None."
+                )
+            if not (custom_list or attribute in ["race", "gender"]):
+                raise ValueError(
+                    "If custom_list is None, attribute must be 'race' or 'gender'."
+                )
+        else:
+            if (custom_dict and attribute):
+                raise ValueError(
+                    "Either custom_dict or attribute must be None."
+                )
+            if not (custom_dict or attribute in ["race", "gender"]):
+                raise ValueError(
+                    "If custom_dict is None, attribute must be 'race' or 'gender'."
+                )
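For context, a hedged sketch of the new validation behavior: invalid argument combinations now raise `ValueError` instead of failing an `assert` (asserts can be stripped under `python -O`). Constructing the generator without an LLM is assumed to be acceptable for parsing-only use, and the import path is an assumption.

```python
# Sketch only, not part of this commit: exercising _validate_attributes via parse_texts.
from langfair.generator import CounterfactualGenerator

cdg = CounterfactualGenerator()  # no LLM needed for parsing (assumption)
try:
    # Supplying both attribute and custom_list is invalid.
    cdg.parse_texts(texts=["he went to the store"], attribute="gender", custom_list=["he"])
except ValueError as err:
    print(err)  # "Either custom_list or attribute must be None."
```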
