diff --git a/tutorials/distributed_data_classification/distributed_data_classification.ipynb b/tutorials/distributed_data_classification/distributed_data_classification.ipynb
index b0fec862c..81c7a6811 100644
--- a/tutorials/distributed_data_classification/distributed_data_classification.ipynb
+++ b/tutorials/distributed_data_classification/distributed_data_classification.ipynb
@@ -4,11 +4,11 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "# Distributed Data Classification with Quality and Domain Classifiers\n",
+ "# Distributed Data Classification with Domain and Quality Classifiers\n",
"\n",
- "The notebook demonstrates the use of two classifiers for distributed data classification, including quality and domain classifiers. The quality classifier is used to classify the quality of the data, while the domain classifier is used to classify the domain of the data. These classifers help with annotation which helps data blending for foundation model training. \n",
+ "The notebook demonstrates the use of two classifiers for distributed data classification, including domain and quality classifiers. The domain classifier is used to classify the domain of the data, while the quality classifier is used to classify the quality of the data. These classifers help with annotation which helps data blending for foundation model training.\n",
"\n",
- "The classifiers are accelerated using CrossFit,(https://github.com/rapidsai/crossfit), a library that leverages intellegent batching and RAPIDS to accelerate the offline inference on large datasets."
+ "The classifiers are accelerated using [CrossFit](https://github.com/rapidsai/crossfit), a library that leverages intellegent batching and RAPIDS to accelerate the offline inference on large datasets."
]
},
{
@@ -25,7 +25,7 @@
}
],
"source": [
- "#### Silence Warnings (HuggingFace internal warnings)\n",
+ "# Silence Warnings (HuggingFace internal warnings)\n",
"\n",
"%env PYTHONWARNINGS=ignore\n",
"import warnings\n",
@@ -41,7 +41,9 @@
"from dask_cuda import LocalCUDACluster\n",
"from dask.distributed import Client\n",
"from nemo_curator import DomainClassifier, QualityClassifier\n",
- "from nemo_curator.datasets import DocumentDataset"
+ "from nemo_curator.datasets import DocumentDataset\n",
+ "import cudf\n",
+ "import dask_cudf"
]
},
{
@@ -58,16 +60,15 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "# Define the data file paths "
+ "# Set File Paths "
]
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
- "input_file_path=\"/input_data_dir/\"\n",
"output_file_path = \"output_data_dir/\"\n",
"domain_model_path = \"domain_model.pth\"\n",
"quality_model_path = \"quality_model.pth\""
@@ -86,79 +87,91 @@
"metadata": {},
"outputs": [],
"source": [
- "classifier_type=\"DomainClassifier\" # or \"QualityClassifier\""
+ "classifier_type = \"DomainClassifier\" # or \"QualityClassifier\""
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Reading 16 files\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "CPU times: user 10.5 s, sys: 5.33 s, total: 15.8 s\n",
- "Wall time: 11.4 s\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "%%time\n",
- "\n",
- "input_dataset = DocumentDataset.read_json(\n",
- " input_file_path, backend=\"cudf\", add_filename=True\n",
- ")\n",
+ "# Create sample DataFrame\n",
+ "text = [\n",
+ " \"Quantum computing is set to revolutionize the field of cryptography.\",\n",
+ " \"Investing in index funds is a popular strategy for long-term financial growth.\",\n",
+ " \"Recent advancements in gene therapy offer new hope for treating genetic disorders.\",\n",
+ " \"Online learning platforms have transformed the way students access educational resources.\",\n",
+ " \"Traveling to Europe during the off-season can be a more budget-friendly option.\",\n",
+ " \"Training regimens for athletes have become more sophisticated with the use of data analytics.\",\n",
+ " \"Streaming services are changing the way people consume television and film content.\",\n",
+ " \"Vegan recipes have gained popularity as more people adopt plant-based diets.\",\n",
+ " \"Climate change research is critical for developing sustainable environmental policies.\",\n",
+ " \"Telemedicine has become increasingly popular due to its convenience and accessibility.\",\n",
+ "]\n",
+ "df = cudf.DataFrame({\"text\": text})\n",
+ "input_dataset = DocumentDataset(dask_cudf.from_cudf(df, npartitions=1))\n",
+ "write_to_filename = False\n",
"\n",
+ "# Alternatively, read existing directory of JSONL files\n",
+ "# input_file_path=\"/input_data_dir/\"\n",
+ "# input_dataset = DocumentDataset.read_json(\n",
+ "# input_file_path, backend=\"cudf\", add_filename=True\n",
+ "# )\n",
+ "# write_to_filename = True"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"if classifier_type == \"DomainClassifier\":\n",
" domain_labels = [\n",
- " \"Adult\",\n",
- " \"Arts_and_Entertainment\",\n",
- " \"Autos_and_Vehicles\",\n",
- " \"Beauty_and_Fitness\",\n",
- " \"Books_and_Literature\",\n",
- " \"Business_and_Industrial\",\n",
- " \"Computers_and_Electronics\",\n",
- " \"Finance\",\n",
- " \"Food_and_Drink\",\n",
- " \"Games\",\n",
- " \"Health\",\n",
- " \"Hobbies_and_Leisure\",\n",
- " \"Home_and_Garden\",\n",
- " \"Internet_and_Telecom\",\n",
- " \"Jobs_and_Education\",\n",
- " \"Law_and_Government\",\n",
- " \"News\",\n",
- " \"Online_Communities\",\n",
- " \"People_and_Society\",\n",
- " \"Pets_and_Animals\",\n",
- " \"Real_Estate\",\n",
- " \"Science\",\n",
- " \"Sensitive_Subjects\",\n",
- " \"Shopping\",\n",
- " \"Sports\",\n",
- " \"Travel_and_Transportation\",\n",
+ " \"Adult\",\n",
+ " \"Arts_and_Entertainment\",\n",
+ " \"Autos_and_Vehicles\",\n",
+ " \"Beauty_and_Fitness\",\n",
+ " \"Books_and_Literature\",\n",
+ " \"Business_and_Industrial\",\n",
+ " \"Computers_and_Electronics\",\n",
+ " \"Finance\",\n",
+ " \"Food_and_Drink\",\n",
+ " \"Games\",\n",
+ " \"Health\",\n",
+ " \"Hobbies_and_Leisure\",\n",
+ " \"Home_and_Garden\",\n",
+ " \"Internet_and_Telecom\",\n",
+ " \"Jobs_and_Education\",\n",
+ " \"Law_and_Government\",\n",
+ " \"News\",\n",
+ " \"Online_Communities\",\n",
+ " \"People_and_Society\",\n",
+ " \"Pets_and_Animals\",\n",
+ " \"Real_Estate\",\n",
+ " \"Science\",\n",
+ " \"Sensitive_Subjects\",\n",
+ " \"Shopping\",\n",
+ " \"Sports\",\n",
+ " \"Travel_and_Transportation\",\n",
" ]\n",
+ "\n",
" classifier = DomainClassifier(\n",
" model_path=domain_model_path,\n",
" labels=domain_labels,\n",
" batch_size=1024,\n",
" )\n",
+ "\n",
"elif classifier_type == \"QualityClassifier\":\n",
" quality_labels = [\"High\", \"Medium\", \"Low\"]\n",
- " model_file_name = \"quality_classifier.pth\"\n",
+ "\n",
" classifier = QualityClassifier(\n",
" model_path=quality_model_path,\n",
" labels=quality_labels,\n",
" batch_size=1024,\n",
" )\n",
+ "\n",
"else:\n",
" raise ValueError(\"Invalid classifier type\")"
]
@@ -188,31 +201,23 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "GPU: 0, Part: 1: 100%|██████████| 938/938 [00:09<00:00, 101.99it/s] \n",
- "GPU: 0, Part: 3: 100%|██████████| 938/938 [00:10<00:00, 92.36it/s] ]\n",
- "GPU: 0, Part: 0: 100%|██████████| 938/938 [00:10<00:00, 91.25it/s] ]\n",
- "GPU: 0, Part: 5: 100%|██████████| 938/938 [00:10<00:00, 88.82it/s] \n",
- "GPU: 0, Part: 14: 100%|██████████| 937/937 [00:10<00:00, 88.11it/s] \n",
- "GPU: 0, Part: 8: 100%|██████████| 937/937 [00:10<00:00, 85.46it/s] ]\n",
- "GPU: 0, Part: 9: 100%|██████████| 937/937 [00:10<00:00, 86.16it/s] \n",
- "GPU: 0, Part: 4: 100%|██████████| 938/938 [00:10<00:00, 85.65it/s]]\n",
- "GPU: 0, Part: 11: 100%|██████████| 937/937 [00:11<00:00, 83.73it/s] \n",
- "GPU: 0, Part: 6: 100%|██████████| 938/938 [00:11<00:00, 83.62it/s]\n",
- "GPU: 0, Part: 10: 100%|██████████| 937/937 [00:11<00:00, 81.27it/s] \n",
- "GPU: 0, Part: 2: 100%|██████████| 938/938 [00:12<00:00, 72.59it/s]]\n",
- "GPU: 0, Part: 7: 100%|██████████| 937/937 [00:13<00:00, 71.75it/s]\n",
- "GPU: 0, Part: 12: 100%|██████████| 937/937 [00:13<00:00, 69.12it/s]\n",
- "GPU: 0, Part: 15: 100%|██████████| 937/937 [00:13<00:00, 68.47it/s]\n",
- "GPU: 0, Part: 13: 100%|██████████| 937/937 [00:14<00:00, 66.29it/s]\n"
+ "GPU: 0, Part: 0: 100%|██████████| 10/10 [00:02<00:00, 3.62it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Writing to disk complete for 16 partitions\n",
- "CPU times: user 2.34 s, sys: 2.24 s, total: 4.58 s\n",
- "Wall time: 17.2 s\n"
+ "Writing to disk complete for 1 partitions\n",
+ "CPU times: user 578 ms, sys: 429 ms, total: 1.01 s\n",
+ "Wall time: 9.91 s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "GPU: 0, Part: 0: 100%|██████████| 10/10 [00:03<00:00, 3.30it/s]\n"
]
}
],
@@ -220,14 +225,14 @@
"%%time\n",
"\n",
"result_dataset = classifier(dataset=input_dataset)\n",
- "result_dataset.to_json(output_file_dir=output_file_path, write_to_filename=True)"
+ "result_dataset.to_json(output_file_dir=output_file_path, write_to_filename=write_to_filename)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "#### Inspect the Output"
+ "# Inspect the Output"
]
},
{
@@ -239,7 +244,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Reading 16 files\n"
+ "Reading 1 files\n"
]
},
{
@@ -263,66 +268,54 @@
" \n",
" \n",
" \n",
" \n",
" \n",
- " adlr_id \n",
" domain_pred \n",
- " filename \n",
- " id \n",
- " pred \n",
- " source_id \n",
- " split_id \n",
" text \n",
- " url \n",
"