Commit f4606fd

Maghoumi authored and VibhuJawa committed
[Tutorials] Add a tutorial for PEFT data curation (NVIDIA#45)
This PR adds a new tutorial to demonstrate data curation for PEFT use-cases.

Signed-off-by: Mehran Maghoumi <[email protected]>
Signed-off-by: Vibhu Jawa <[email protected]>
1 parent 5387ccd commit f4606fd

File tree

7 files changed (+432, -2 lines)


tutorials/peft-curation/README.md

+19
@@ -0,0 +1,19 @@
# Curating Datasets for Parameter Efficient Fine-tuning

This tutorial demonstrates the usage of NeMo Curator's Python API to curate a dataset for
parameter-efficient fine-tuning (PEFT).

In this tutorial, we use the [Enron Emails dataset](https://huggingface.co/datasets/neelblabla/enron_labeled_emails_with_subjects-llama2-7b_finetuning),
a collection of emails in which each record has a subject, a body, and a classification label
(its category). We demonstrate various filtering and processing operations that can be applied
to each record.

## Usage
After installing the NeMo Curator package, you can simply run the following command:
```
python tutorials/peft-curation/main.py
```

By default, this tutorial uses at most 8 workers to run the curation pipeline. If you run into
out-of-memory issues, you can reduce the number of workers by supplying the `--n-workers=N`
argument, where `N` is the number of workers to spawn.
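
For example, to cap the pipeline at 4 workers on a memory-constrained machine:

```
python tutorials/peft-curation/main.py --n-workers=4
```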

tutorials/peft-curation/docbuilder.py

+113
@@ -0,0 +1,113 @@
```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
from typing import Dict

import requests

from nemo_curator.download.doc_builder import (
    DocumentDownloader,
    DocumentExtractor,
    DocumentIterator,
)


class EmailsDownloader(DocumentDownloader):
    def __init__(self, download_dir: str):
        super().__init__()

        if not os.path.isdir(download_dir):
            os.makedirs(download_dir)

        self._download_dir = download_dir
        print("Download directory: ", self._download_dir)

    def download(self, url: str) -> str:
        filename = os.path.basename(url)
        output_file = os.path.join(self._download_dir, filename)

        if os.path.exists(output_file):
            print(f"File '{output_file}' already exists, skipping download.")
            return output_file

        print(f"Downloading Enron emails dataset from '{url}'...")
        response = requests.get(url)

        with open(output_file, "wb") as file:
            file.write(response.content)

        return output_file


class EmailsIterator(DocumentIterator):

    def __init__(self):
        super().__init__()
        self._counter = -1
        self._extractor = EmailsExtractor()
        # The regular expression pattern to extract each email.
        self._pattern = re.compile(r"\"<s>.*?<s>\"", re.DOTALL)

    def iterate(self, file_path):
        self._counter = -1
        file_name = os.path.basename(file_path)

        with open(file_path, "r", encoding="utf-8") as file:
            lines = file.readlines()

        # Ignore the first line, which contains the header.
        file_content = "".join(lines[1:])
        # Find all the emails in the file.
        it = self._pattern.finditer(file_content)

        for email in it:
            self._counter += 1
            content = email.group().strip('"').strip()
            meta = {
                "filename": file_name,
                "id": f"email-{self._counter}",
            }
            extracted_content = self._extractor.extract(content)

            # Skip if no content was extracted.
            if not extracted_content:
                continue

            record = {**meta, **extracted_content}
            yield record


class EmailsExtractor(DocumentExtractor):
    def __init__(self):
        super().__init__()
        # The regular expression pattern to extract subject/body/label into groups.
        self._pattern = re.compile(
            r"Subject:: (.*?)\nBody:: (.*?)\n.*\[/INST\] (.*?) <s>", re.DOTALL
        )

    def extract(self, content: str) -> Dict[str, str]:
        matches = self._pattern.findall(content)

        if not matches:
            return None

        matches = matches[0]

        return {
            "subject": matches[0].strip(),
            "body": matches[1].strip(),
            "category": matches[2].strip(),
        }
```
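
A minimal sketch of driving the iterator standalone to inspect what it yields (the input path is
illustrative and assumes the raw CSV was already fetched by `EmailsDownloader`):

```python
from docbuilder import EmailsIterator

iterator = EmailsIterator()

# Iterate the raw file and inspect the first extracted record. Each record
# carries the metadata ("filename", "id") plus the extracted fields
# ("subject", "body", "category").
for record in iterator.iterate("data/prompts_train.csv"):
    print(record["id"], record["subject"])
    break
```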

tutorials/peft-curation/filters.py

+47
@@ -0,0 +1,47 @@
```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo_curator.filters import DocumentFilter


class FilterEmailsWithLongBody(DocumentFilter):
    """
    If the email is too long, discard.
    """

    def __init__(self, max_length: int = 5000):
        super().__init__()
        self.max_length = max_length

    def score_document(self, text: str) -> bool:
        return len(text) <= self.max_length

    def keep_document(self, score) -> bool:
        return score


class FilterEmptyEmails(DocumentFilter):
    """
    Detects empty emails (either empty body, or labeled as empty). Returns `True` for empty emails.
    """

    def score_document(self, text: str) -> bool:
        return (
            not isinstance(text, str)  # The text is not a string
            or len(text.strip()) == 0  # The text is empty
            or "Empty message" in text  # The email is labeled as empty
        )

    def keep_document(self, score) -> bool:
        return score
```
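
As a quick sanity check, the filters can be scored directly on strings; a minimal sketch (the
sample texts are made up):

```python
from filters import FilterEmailsWithLongBody, FilterEmptyEmails

# FilterEmptyEmails scores True for empty or placeholder text. The pipeline in
# main.py applies it with invert=True, so a True score means the record is dropped.
print(FilterEmptyEmails().score_document("Empty message"))    # True -> dropped
print(FilterEmptyEmails().score_document("Quarterly report")) # False -> kept

# FilterEmailsWithLongBody scores True (keep) when the body fits within max_length.
print(FilterEmailsWithLongBody(max_length=10).score_document("short"))  # True -> kept
```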

tutorials/peft-curation/main.py

+179
@@ -0,0 +1,179 @@
```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os
from functools import partial
from typing import Any

from docbuilder import EmailsDownloader, EmailsIterator
from filters import FilterEmailsWithLongBody, FilterEmptyEmails
from modifiers import AddPeriod, AddSystemPrompt

from nemo_curator import ScoreFilter, Sequential
from nemo_curator.datasets import DocumentDataset
from nemo_curator.modifiers.pii_modifier import PiiModifier
from nemo_curator.modifiers.unicode_reformatter import UnicodeReformatter
from nemo_curator.modules.modify import Modify
from nemo_curator.utils.distributed_utils import get_client
from nemo_curator.utils.script_utils import add_distributed_args

SCRIPT_DIR_PATH = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(SCRIPT_DIR_PATH, "data")
DATASET_URL = "https://huggingface.co/datasets/neelblabla/enron_labeled_emails_with_subjects-llama2-7b_finetuning/raw/main/prompts_train.csv"


def download_and_convert_to_jsonl() -> str:
    """
    Downloads the emails dataset and converts it to JSONL format.

    Returns:
        str: The path to the JSONL file.
    """

    # Download the dataset in raw format and convert it to JSONL.
    downloader = EmailsDownloader(DATA_DIR)
    output_path = os.path.join(DATA_DIR, "emails.jsonl")
    raw_fp = downloader.download(DATASET_URL)

    iterator = EmailsIterator()

    # Parse the raw data and write it to a JSONL file.
    with open(output_path, "w") as f:
        for record in iterator.iterate(raw_fp):
            json_record = json.dumps(record, ensure_ascii=False)
            f.write(json_record + "\n")

    return output_path


def redact_pii(dataset: DocumentDataset, text_field: str) -> DocumentDataset:
    """
    Redacts personally identifiable information (PII) from a given dataset.

    Args:
        dataset (DocumentDataset): The dataset containing documents with PII.
        text_field (str): The name of the text field to redact.

    Returns:
        DocumentDataset: The redacted dataset with PII replaced by a generic value.
    """
    redactor = Modify(
        PiiModifier(
            supported_entities=[
                "ADDRESS",
                "EMAIL_ADDRESS",
                "LOCATION",
                "PERSON",
                "URL",
                "PHONE_NUMBER",
            ],
            anonymize_action="replace",
            device="cpu",
        ),
        text_field=text_field,
    )
    return redactor(dataset)


def run_curation_pipeline(args: Any, jsonl_fp: str) -> str:
    """
    Run the curation pipeline on the dataset.

    Args:
        args (Any): Command-line arguments.
        jsonl_fp (str): The path to the uncurated JSONL file.

    Returns:
        str: The path to the curated JSONL file.
    """
    client = get_client(args, args.device)
    print(f" Running the curation pipeline on '{jsonl_fp}'...")
    orig_dataset = DocumentDataset.read_json(jsonl_fp, add_filename=True)
    dataset = orig_dataset

    redact_pii_subject = partial(redact_pii, text_field="subject")
    redact_pii_body = partial(redact_pii, text_field="body")

    curation_steps = Sequential(
        [
            #
            # Unify the text encoding to Unicode.
            #
            Modify(UnicodeReformatter(), text_field="subject"),
            Modify(UnicodeReformatter(), text_field="body"),
            Modify(UnicodeReformatter(), text_field="category"),
            #
            # Filtering
            #
            # Filter out empty emails.
            ScoreFilter(
                FilterEmptyEmails(), text_field="subject", score_type=bool, invert=True
            ),
            ScoreFilter(
                FilterEmptyEmails(), text_field="body", score_type=bool, invert=True
            ),
            ScoreFilter(
                FilterEmptyEmails(), text_field="category", score_type=bool, invert=True
            ),
            # Filter out emails that are too long.
            ScoreFilter(FilterEmailsWithLongBody(), text_field="body", score_type=bool),
            #
            # Redact personally identifiable information (PII).
            #
            redact_pii_subject,
            redact_pii_body,
            #
            # Final modifications.
            #
            # Add system prompts to every email, which helps the model focus on the task.
            Modify(AddSystemPrompt(), text_field="body"),
            # Add a period to the end of each email category, which makes PEFT easier.
            Modify(AddPeriod(), text_field="category"),
        ]
    )

    dataset = curation_steps(dataset)
    dataset = dataset.persist()

    print(f" Original dataset length: {len(orig_dataset.df)}")
    print(f" After running the curation pipeline: {len(dataset.df)}")
    print(f" Writing to '{jsonl_fp}'...")
    out_path = os.path.join(
        os.path.dirname(jsonl_fp),
        "curated",
    )
    os.makedirs(out_path, exist_ok=True)
    dataset.to_json(out_path, write_to_filename=True)
    client.close()
    return os.path.join(out_path, os.path.basename(jsonl_fp))


def main():
    parser = argparse.ArgumentParser()
    parser = add_distributed_args(parser)
    args = parser.parse_args()
    # Limit the total number of workers to ensure we don't run out of memory.
    args.n_workers = min(args.n_workers, 8)

    # Prepare the download and JSONL directories.
    if not os.path.isdir(DATA_DIR):
        os.makedirs(DATA_DIR)

    jsonl_fp = download_and_convert_to_jsonl()
    run_curation_pipeline(args, jsonl_fp)


if __name__ == "__main__":
    main()
```
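
Once the run completes, the curated output can be inspected as plain JSONL; a minimal sketch (the
path follows the defaults in main.py, which write to a `curated/` subdirectory next to the input,
keeping the input's file name per the return value of `run_curation_pipeline`):

```python
import json

# Default output location per run_curation_pipeline: data/curated/emails.jsonl
with open("data/curated/emails.jsonl", "r", encoding="utf-8") as f:
    first_record = json.loads(f.readline())

print(list(first_record.keys()))  # e.g. filename, id, subject, body, category
```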
