|
20 | 20 | "source": [
|
21 | 21 | "Content\n",
|
22 | 22 | "1. [Introduction](#section1')\n",
|
23 |
| - "2. [Generate Demo Dataset](#section2')\n", |
| 23 | + "2. [Generate Counterfactual Dataset](#section2')<br>\n", |
| 24 | + " 2.1 [Check fairness through unawareness](#section2-1')<br>\n", |
| 25 | + " 2.2 [Generate counterfactual responses](#section2-2')\n", |
24 | 26 | "3. [Assessment](#section3')<br>\n",
|
25 | 27 | " 3.1 [Lazy Implementation](#section3-1')<br>\n",
|
26 | 28 | " 3.2 [Separate Implementation](#section3-2')\n",
|
|
120 | 122 | "metadata": {},
|
121 | 123 | "source": [
|
122 | 124 | "<a id='section2'></a>\n",
|
123 |
| - "## 2. Generate Demo Dataset" |
| 125 | + "## 2. Generate Counterfactual Dataset" |
124 | 126 | ]
|
125 | 127 | },
|
126 | 128 | {
|
|
163 | 165 | "tags": []
|
164 | 166 | },
|
165 | 167 | "source": [
|
166 |
| - "### Counterfactual Dataset Generator\n", |
| 168 | + "#### Counterfactual Dataset Generator\n", |
167 | 169 | "***\n",
|
168 |
| - "##### `CounterfactualGenerator()` - Class for generating data for counterfactual discrimination assessment (class)\n", |
| 170 | + "##### `CounterfactualGenerator()` - Used for generating data for counterfactual fairness assessment (class)\n", |
169 | 171 | "\n",
|
170 | 172 | "**Class Attributes:**\n",
|
171 | 173 | "\n",
|
172 |
| - "- `langchain_llm` (**langchain llm (Runnable), default=None**) A langchain llm object to get passed to LLMChain `llm` argument. \n", |
| 174 | + "- `langchain_llm` (**langchain llm (Runnable), default=None**) A LangChain llm object to get passed to LangChain `RunnableSequence`. \n", |
173 | 175 | "- `suppressed_exceptions` (**tuple, default=None**) Specifies which exceptions to handle as 'Unable to get response' rather than raising the exception\n",
|
174 |
| - "- `max_calls_per_min` (**deprecated as of 0.2.0**) Use LangChain's InMemoryRateLimiter instead.\n", |
175 |
| - "\n", |
176 |
| - "**Methods:**\n", |
177 |
| - "\n", |
178 |
| - "1. `parse_texts()` - Parses a list of texts for protected attribute words and names\n", |
179 |
| - "\n", |
180 |
| - " **Method Parameters:**\n", |
181 |
| - "\n", |
182 |
| - " - `text` - (**string**) A text corpus to be parsed for protected attribute words and names\n", |
183 |
| - " - `attribute` - (**{'race','gender','name'}**) Specifies what to parse for among race words, gender words, and names\n", |
184 |
| - " - `custom_list` - (**List[str], default=None**) Custom list of tokens to use for parsing prompts. Must be provided if attribute is None.\n", |
185 |
| - " \n", |
186 |
| - " **Returns:**\n", |
187 |
| - " - list of results containing protected attribute words found (**list**)\n", |
188 |
| - "\n", |
189 |
| - "2. `create_prompts()` - Creates counterfactual prompts by counterfactual substitution\n", |
190 |
| - "\n", |
191 |
| - " **Method Parameters:**\n", |
192 |
| - "\n", |
193 |
| - " - `prompts` - (**List of strings**) A list of prompts on which counterfactual substitution and response generation will be done\n", |
194 |
| - " - `attribute` - (**{'gender', 'race'}, default=None**) Specifies what to parse for among race words and gender words. Must be specified if custom_list is None.\n", |
195 |
| - " - `custom_dict` - (**Dict[str, List[str]], default=None**) A dictionary containing corresponding lists of tokens for counterfactual substitution. Keys should correspond to groups. Must be provided if attribute is None. For example: {'male': ['he', 'him', 'woman'], 'female': ['she', 'her', 'man']}\n", |
196 |
| - " subset_prompts : bool, default=True\n", |
197 |
| - " \n", |
198 |
| - " **Returns:**\n", |
199 |
| - " - list of prompts on which counterfactual substitution was completed (**list**)\n", |
200 |
| - " \n", |
201 |
| - "3. `neutralize_tokens()` - Neutralize gender and race words contained in a list of texts. Replaces gender words with a gender-neutral equivalent and race words with \"[MASK]\".\n", |
202 |
| - "\n", |
203 |
| - " **Method Parameters:**\n", |
204 |
| - "\n", |
205 |
| - " - `text_list` - (**List of strings**) A list of texts on which gender or race neutralization will occur\n", |
206 |
| - " - `attribute` - (**{'gender', 'race'}, default='gender'**) Specifies whether to use race or gender for for neutralization\n", |
207 |
| - "\n", |
208 |
| - " **Returns:**\n", |
209 |
| - " - list of texts neutralized with respect to race or gender (**list**)\n", |
210 |
| - "\n", |
211 |
| - "4. `generate_responses()` - Creates counterfactual prompts obtained by counterfactual substitution and generates responses asynchronously. \n", |
212 |
| - "\n", |
213 |
| - " **Method Parameters:**\n", |
214 |
| - "\n", |
215 |
| - " - `prompts` - (**List of strings**) A list of prompts on which counterfactual substitution and response generation will be done\n", |
216 |
| - " - `attribute` - (**{'gender', 'race'}, default='gender'**) Specifies whether to use race or gender for counterfactual substitution\n", |
217 |
| - " - `system_prompt` - (**str, default=\"You are a helpful assistant.\"**) Specifies system prompt for generation \n", |
218 |
| - " - `count` - (**int, default=25**) Specifies number of responses to generate for each prompt.\n", |
219 |
| - " - `custom_dict` - (**Dict[str, List[str]], default=None**) A dictionary containing corresponding lists of tokens for counterfactual substitution. Keys should correspond to groups. Must be provided if attribute is None. For example: {'male': ['he', 'him', 'woman'], 'female': ['she', 'her', 'man']}\n", |
220 |
| - "\n", |
221 |
| - " **Returns:** A dictionary with two keys: `data` and `metadata`.\n", |
222 |
| - " - `data` (**dict**) A dictionary containing the prompts and responses.\n", |
223 |
| - " - `metadata` (**dict**) A dictionary containing metadata about the generation process, including non-completion rate, temperature, count, original prompts, and identified proctected attribute words." |
| 176 | + "- `max_calls_per_min` (**deprecated as of 0.2.0**) Use LangChain's InMemoryRateLimiter instead." |
224 | 177 | ]
|
225 | 178 | },
|
226 | 179 | {
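The class attributes above map directly onto constructor arguments. Below is a minimal sketch of instantiating the generator with a rate-limited LangChain chat model, reflecting the deprecation of `max_calls_per_min` in favor of `InMemoryRateLimiter`; the import paths (`langfair.generator`, `langchain_openai`, `openai`), the model name, and the suppressed exception type are illustrative assumptions, not taken from this notebook.

```python
# Minimal sketch: wiring a rate-limited LangChain chat model into
# CounterfactualGenerator. Import paths, model name, and exception type are
# assumptions; adjust to your provider.
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_openai import ChatOpenAI
from openai import RateLimitError

from langfair.generator import CounterfactualGenerator

# max_calls_per_min is deprecated; throttle on the LangChain side instead.
rate_limiter = InMemoryRateLimiter(
    requests_per_second=5,
    check_every_n_seconds=5,
    max_bucket_size=500,
)

llm = ChatOpenAI(model="gpt-4o-mini", temperature=1, rate_limiter=rate_limiter)

cdg = CounterfactualGenerator(
    langchain_llm=llm,
    suppressed_exceptions=(RateLimitError,),  # handled as 'Unable to get response'
)
```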
|
|
366 | 319 | "cell_type": "markdown",
|
367 | 320 | "metadata": {},
|
368 | 321 | "source": [
|
369 |
| - "For illustration, this notebook assesses with 'race' as the protected attribute, but metrics can be evaluated for 'gender' or other custom protected attributes in the same way. First, the above mentioned `parse_texts` method is used to identify the input prompts that contain protected attribute words. \n", |
| 322 | + "<a id='section2-1'></a>\n", |
| 323 | + "### 2.1 Check fairness through unawareness" |
| 324 | + ] |
| 325 | + }, |
| 326 | + { |
| 327 | + "cell_type": "markdown", |
| 328 | + "metadata": {}, |
| 329 | + "source": [ |
| 330 | + "#### `CounterfactualGenerator.check_ftu()` - Parses prompts to check for fairness through unawareness. Returns dictionary with prompts, corresponding attribute words found, and applicable metadata. \n", |
| 331 | + "\n", |
| 332 | + "**Method Parameters:**\n", |
| 333 | + "\n", |
| 334 | + "- `text` - (**string**) A text corpus to be parsed for protected attribute words and names\n", |
| 335 | + "- `attribute` - (**{'race','gender','name'}**) Specifies what to parse for among race words, gender words, and names\n", |
| 336 | + "- `custom_list` - (**List[str], default=None**) Custom list of tokens to use for parsing prompts. Must be provided if attribute is None.\n", |
| 337 | + "- `subset_prompts` - (**bool, default=True**) Indicates whether to return all prompts or only those containing attribute words\n", |
| 338 | + "\n", |
| 339 | + "**Returns:**\n", |
| 340 | + "- dictionary with prompts, corresponding attribute words found, and applicable metadata (**dict**)" |
| 341 | + ] |
| 342 | + }, |
| 343 | + { |
| 344 | + "cell_type": "markdown", |
| 345 | + "metadata": {}, |
| 346 | + "source": [ |
| 347 | + "For illustration, this notebook assesses with 'race' as the protected attribute, but metrics can be evaluated for 'gender' or other custom protected attributes in the same way. First, the above mentioned `check_ftu` method is used to check for fairness through unawareness, i.e. whether prompts contain mentions of protected attribute words. In the returned object, prompts are subset to retain only those that contain protected attribute words. \n", |
370 | 348 | "\n",
|
371 | 349 | "Note: We recommend using atleast 1000 prompts that contain protected attribute words for better estimates. Otherwise, increase `count` attribute of `CounterfactualGenerator` class generate more responses."
|
372 | 350 | ]
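The next cell runs `check_ftu` with the built-in `'race'` attribute. For prompts that need screening against a user-defined vocabulary, the `custom_list` parameter documented above can be used instead; the sketch below is illustrative only, with assumed tokens, and carries over the `cdg` and `prompts` objects from earlier cells.

```python
# Sketch: fairness-through-unawareness check against a custom token list
# (used when attribute is None). The tokens are illustrative; `cdg` and
# `prompts` are assumed to exist from earlier cells.
import pandas as pd

custom_tokens = ["veteran", "immigrant", "retiree"]

ftu_custom = cdg.check_ftu(
    prompts=prompts,
    custom_list=custom_tokens,
    subset_prompts=True,  # keep only prompts containing a listed token
)
pd.DataFrame(ftu_custom["data"]).head()
```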
|
|
456 | 434 | ],
|
457 | 435 | "source": [
|
458 | 436 | "# Check for fairness through unawareness\n",
|
459 |
| - "attribute = 'race'\n", |
460 |
| - "df = pd.DataFrame({'prompt': prompts})\n", |
461 |
| - "df[attribute + '_words'] = cdg.parse_texts(texts=prompts, attribute=attribute)\n", |
462 |
| - "\n", |
463 |
| - "# Remove input prompts that doesn't include a race word\n", |
464 |
| - "race_prompts = df[df['race_words'].apply(lambda x: len(x) > 0)][['prompt','race_words']]\n", |
465 |
| - "print(f\"Race words found in {len(race_prompts)} prompts\")\n", |
| 437 | + "ftu_result = cdg.check_ftu(\n", |
| 438 | + " prompts=prompts,\n", |
| 439 | + " attribute='race',\n", |
| 440 | + " subset_prompts=True\n", |
| 441 | + ")\n", |
| 442 | + "race_prompts = pd.DataFrame(ftu_result[\"data\"]).rename(columns={'attribute_words': 'race_words'})\n", |
466 | 443 | "race_prompts.tail(5)"
|
467 | 444 | ]
|
468 | 445 | },
|
469 | 446 | {
|
470 | 447 | "cell_type": "markdown",
|
471 | 448 | "metadata": {},
|
472 | 449 | "source": [
|
473 |
| - "Generate the model response on the input prompts using `generate_responses` method." |
| 450 | + "As seen above, this use case does not satisfy fairness through unawareness, since 246 prompts contain mentions of race words." |
| 451 | + ] |
| 452 | + }, |
| 453 | + { |
| 454 | + "cell_type": "markdown", |
| 455 | + "metadata": {}, |
| 456 | + "source": [ |
| 457 | + "<a id='section2-2'></a>\n", |
| 458 | + "### 2.2 Generate counterfactual responses" |
| 459 | + ] |
| 460 | + }, |
| 461 | + { |
| 462 | + "cell_type": "markdown", |
| 463 | + "metadata": {}, |
| 464 | + "source": [ |
| 465 | + "#### `CounterfactualGenerator.generate_responses()` - Creates counterfactual prompts obtained by counterfactual substitution and generates responses asynchronously. \n", |
| 466 | + "\n", |
| 467 | + "**Method Parameters:**\n", |
| 468 | + "\n", |
| 469 | + "- `prompts` - (**List of strings**) A list of prompts on which counterfactual substitution and response generation will be done\n", |
| 470 | + "- `attribute` - (**{'gender', 'race'}, default='gender'**) Specifies whether to use race or gender for counterfactual substitution\n", |
| 471 | + "- `system_prompt` - (**str, default=\"You are a helpful assistant.\"**) Specifies system prompt for generation \n", |
| 472 | + "- `count` - (**int, default=25**) Specifies number of responses to generate for each prompt.\n", |
| 473 | + "- `custom_dict` - (**Dict[str, List[str]], default=None**) A dictionary containing corresponding lists of tokens for counterfactual substitution. Keys should correspond to groups. Must be provided if attribute is None. For example: {'male': ['he', 'him', 'woman'], 'female': ['she', 'her', 'man']}\n", |
| 474 | + "\n", |
| 475 | + "**Returns:** A dictionary with two keys: `data` and `metadata`.\n", |
| 476 | + "- `data` (**dict**) A dictionary containing the prompts and responses.\n", |
| 477 | + "- `metadata` (**dict**) A dictionary containing metadata about the generation process, including non-completion rate, temperature, count, original prompts, and identified proctected attribute words." |
| 478 | + ] |
| 479 | + }, |
| 480 | + { |
| 481 | + "cell_type": "markdown", |
| 482 | + "metadata": {}, |
| 483 | + "source": [ |
| 484 | + "Create counterfactual input prompts and generate corresponding LLM responses using `generate_responses` method." |
474 | 485 | ]
|
475 | 486 | },
|
476 | 487 | {
|
|
566 | 577 | ],
|
567 | 578 | "source": [
|
568 | 579 | "generations = await cdg.generate_responses(\n",
|
569 |
| - " prompts=df['prompt'], attribute='race', count=1\n", |
| 580 | + " prompts=race_prompts['prompt'], attribute='race', count=1\n", |
570 | 581 | ")\n",
|
571 | 582 | "output_df = pd.DataFrame(generations['data'])\n",
|
572 | 583 | "output_df.head(1)"
|
|
617 | 628 | "cell_type": "markdown",
|
618 | 629 | "metadata": {},
|
619 | 630 | "source": [
|
620 |
| - "### `CounterfactualMetrics()` - Calculate all the counterfactual metrics (class)\n", |
| 631 | + "#### `CounterfactualMetrics()` - Calculate all the counterfactual metrics (class)\n", |
621 | 632 | "**Class Attributes:**\n",
|
622 | 633 | "- `metrics` - (**List of strings/Metric objects**) Specifies which metrics to use.\n",
|
623 | 634 | "Default option is a list if strings (`metrics` = [\"Cosine\", \"Rougel\", \"Bleu\", \"Sentiment Bias\"]).\n",
|
|
1206 | 1217 | "uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/workbench-notebooks:m125"
|
1207 | 1218 | },
|
1208 | 1219 | "kernelspec": {
|
1209 |
| - "display_name": "langchain", |
| 1220 | + "display_name": ".venv", |
1210 | 1221 | "language": "python",
|
1211 |
| - "name": "langchain" |
| 1222 | + "name": "python3" |
1212 | 1223 | },
|
1213 | 1224 | "language_info": {
|
1214 | 1225 | "codemirror_mode": {
|
|
1220 | 1231 | "name": "python",
|
1221 | 1232 | "nbconvert_exporter": "python",
|
1222 | 1233 | "pygments_lexer": "ipython3",
|
1223 |
| - "version": "3.11.10" |
| 1234 | + "version": "3.9.6" |
1224 | 1235 | }
|
1225 | 1236 | },
|
1226 | 1237 | "nbformat": 4,
|
|