Add dtype support (optional) when reading jsonl files
Signed-off-by: Miguel Martínez <[email protected]>
miguelusque committed May 22, 2024
1 parent 06ee061 commit 638f7ff
Showing 15 changed files with 208 additions and 54 deletions.
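For orientation, a minimal sketch of how the new optional input_meta argument added by this commit might be used when reading a JSONL dataset; the import path mirrors the example scripts, while the dataset path and the "id"/"text" field names are placeholders rather than anything this commit prescribes:

from nemo_curator.datasets import DocumentDataset

# input_meta is passed as a string holding a dict literal, i.e. the same value
# a user would supply to the new --input-meta command-line flag. The field
# names below are illustrative only.
dataset = DocumentDataset.read_json(
    "my_dataset_dir/",
    backend="pandas",
    add_filename=True,
    input_meta='{"id": "str", "text": "str"}',
)
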
13 changes: 12 additions & 1 deletion examples/blend_and_shuffle.py
@@ -31,7 +31,10 @@ def main(args):
client = get_client(args, args.device)

# Blend the datasets
datasets = [DocumentDataset.read_json(path) for path in dataset_paths]
datasets = [
DocumentDataset.read_json(path, input_meta=args.input_meta)
for path in dataset_paths
]
blended_dataset = nc.blend_datasets(target_size, datasets, dataset_weights)

shuffle = nc.Shuffle(seed=42)
@@ -46,6 +49,14 @@ def attach_args(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
),
):
parser.add_argument(
"--input-meta",
type=str,
default=None,
help="A dictionary containing the json object field names and their "
"corresponding data types.",
)

return add_distributed_args(parser)


@@ -62,7 +62,7 @@ def main(args):
client = get_client(args, cluster_type=args.device)

input_dataset = DocumentDataset.read_json(
input_file_path, backend="cudf", add_filename=True
input_file_path, backend="cudf", add_filename=True, input_meta=args.input_meta
)

domain_classifier = DomainClassifier(
@@ -134,6 +134,13 @@ def attach_args(
default="gpu",
help="Device to run the script on. Either 'cpu' or 'gpu'.",
)
parser.add_argument(
"--input-meta",
type=str,
default=None,
help="A dictionary containing the json object field names and their "
"corresponding data types.",
)

return parser

12 changes: 11 additions & 1 deletion examples/exact_deduplication.py
@@ -40,7 +40,9 @@ def main(args):
client.run(pre_imports)

t0 = time.time()
input_dataset = DocumentDataset.read_json(dataset_dir, backend=backend)
input_dataset = DocumentDataset.read_json(
dataset_dir, backend=backend, input_meta=args.input_meta
)

exact_dup = ExactDuplicates(
logger=log_dir,
@@ -79,6 +81,14 @@ def attach_args(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
),
):
parser.add_argument(
"--input-meta",
type=str,
default=None,
help="A dictionary containing the json object field names and their "
"corresponding data types.",
)

return add_distributed_args(parser)


11 changes: 9 additions & 2 deletions examples/fuzzy_deduplication.py
@@ -59,8 +59,7 @@ def main(args):
)
elif filetype == "jsonl":
input_dataset = DocumentDataset.read_json(
dataset_dir,
backend=backend,
dataset_dir, backend=backend, input_meta=args.input_meta
)

fuzzy_dedup_config = FuzzyDuplicatesConfig(
@@ -102,6 +101,14 @@ def attach_args(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
),
):
parser.add_argument(
"--input-meta",
type=str,
default=None,
help="A dictionary containing the json object field names and their "
"corresponding data types.",
)

return add_distributed_args(parser)


24 changes: 20 additions & 4 deletions examples/identify_languages_and_fix_unicode.py
@@ -27,9 +27,15 @@
from nemo_curator.utils.script_utils import add_distributed_args


def load_dataset(input_data_dir):
def load_dataset(input_data_dir: str, input_meta: str = None):
files = list(get_all_files_paths_under(input_data_dir))
raw_data = read_data(files, file_type="jsonl", backend="pandas", add_filename=True)
raw_data = read_data(
files,
file_type="jsonl",
backend="pandas",
add_filename=True,
input_meta=input_meta,
)
dataset = DocumentDataset(raw_data)

return dataset
@@ -52,7 +58,9 @@ def main(args):
client = get_client(args, args.device)

# Filter data
multilingual_dataset = load_dataset(multilingual_data_path)
multilingual_dataset = load_dataset(
input_data_dir=multilingual_data_path, input_meta=args.input_meta
)
language_id_pipeline = nc.ScoreFilter(
FastTextLangId(model_path), score_field=language_field, score_type="object"
)
@@ -74,7 +82,7 @@ def main(args):
lang_data_path = os.path.join(language_separated_output_path, target_language)
if not os.path.exists(lang_data_path):
raise RuntimeError(f"Dataset did not have language: {target_language}")
lang_data = load_dataset(lang_data_path)
lang_data = load_dataset(input_data_dir=lang_data_path, input_meta=args.input_meta)

cleaner = nc.Modify(UnicodeReformatter())
cleaned_data = cleaner(lang_data)
@@ -88,6 +96,14 @@ def attach_args(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
),
):
parser.add_argument(
"--input-meta",
type=str,
default=None,
help="A dictionary containing the json object field names and their "
"corresponding data types.",
)

return add_distributed_args(parser)


42 changes: 25 additions & 17 deletions nemo_curator/datasets/doc_dataset.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union

import dask.dataframe as dd

from nemo_curator.utils.distributed_utils import read_data, write_to_disk
@@ -36,10 +38,11 @@ def persist(self):
@classmethod
def read_json(
cls,
input_files,
backend="pandas",
files_per_partition=1,
add_filename=False,
input_files: Union[str, List[str]],
backend: str = "pandas",
files_per_partition: int = 1,
add_filename: bool = False,
input_meta: str = None,
):
return cls(
_read_json_or_parquet(
@@ -48,6 +51,7 @@ def read_json(
backend=backend,
files_per_partition=files_per_partition,
add_filename=add_filename,
input_meta=input_meta,
)
)

@@ -77,16 +81,16 @@ def read_pickle(
files_per_partition=1,
add_filename=False,
):
raw_data = read_data(
input_files=input_files,
file_type="pickle",
backend=backend,
files_per_partition=files_per_partition,
add_filename=add_filename,
return cls(
read_data(
input_files=input_files,
file_type="pickle",
backend=backend,
files_per_partition=files_per_partition,
add_filename=add_filename,
)
)

return cls(raw_data)

def to_json(
self,
output_file_dir,
@@ -128,11 +132,12 @@ def to_pickle(


def _read_json_or_parquet(
input_files,
file_type,
backend,
files_per_partition,
add_filename,
input_files: Union[str, List[str]],
file_type: str,
backend: str,
files_per_partition: int,
add_filename: bool,
input_meta: str = None,
):
"""
`input_files` may be a list or a string type.
@@ -162,6 +167,7 @@ def _read_json_or_parquet(
backend=backend,
files_per_partition=files_per_partition,
add_filename=add_filename,
input_meta=input_meta,
)

# List of directories
@@ -178,6 +184,7 @@
backend=backend,
files_per_partition=files_per_partition,
add_filename=add_filename,
input_meta=input_meta,
)
dfs.append(df)

@@ -200,6 +207,7 @@
backend=backend,
files_per_partition=files_per_partition,
add_filename=add_filename,
input_meta=input_meta,
)

else:
38 changes: 28 additions & 10 deletions nemo_curator/distributed_data_classification/verify_results.py
@@ -13,6 +13,7 @@
# limitations under the License.

import argparse
import ast
import os

import pandas as pd
@@ -27,30 +28,39 @@ def parse_args():
"""
parser = argparse.ArgumentParser(description="Run verification")

parser.add_argument(
"--results_file_path",
type=str,
help="The path of the input files",
required=True,
help="The path of the input files",
)
parser.add_argument(
"--expected_results_file_path",
type=str,
help="The path of the expected_result file",
required=True,
help="The path of the expected_result file",
)
parser.add_argument(
"--results_pred_column",
type=str,
help="The prediction column name for the input files",
default="pred",
help="The prediction column name for the input files",
)
parser.add_argument(
"--expected_pred_column",
type=str,
help="The prediction column name for the expected_result file",
default="pred",
help="The prediction column name for the expected_result file",
)
parser.add_argument(
"--input-meta",
type=str,
default=None,
help="A dictionary containing the json object field names and their "
"corresponding data types.",
)

return parser.parse_args()


@@ -122,10 +132,11 @@ def verify_same_dataframe(


def verify_results(
results_file_path,
expected_results_file_path,
results_pred_column,
expected_pred_column,
results_file_path: str,
expected_results_file_path: str,
results_pred_column: str,
expected_pred_column: str,
input_meta: str = None,
):
"""
This function compares an input file with its expected result file.
@@ -138,7 +149,10 @@
expected_pred_column: The prediction column name for the expected_result file.
"""
expected_df = pd.read_json(expected_results_file_path, lines=True)
if input_meta:
input_meta = ast.literal_eval(input_meta)

expected_df = pd.read_json(expected_results_file_path, lines=True, dtype=input_meta)
expected_df = expected_df.sort_values(by=["text"]).reset_index(drop=True)
expected_counts = expected_df[expected_pred_column].value_counts().to_dict()

@@ -150,7 +164,10 @@
]

got_paths = [p for p in os.scandir(results_file_path)]
got_df = [pd.read_json(path, lines=True)[expected_columns] for path in got_paths]
got_df = [
pd.read_json(path, lines=True, dtype=input_meta)[expected_columns]
for path in got_paths
]
got_df = pd.concat(got_df, ignore_index=True)
got_df = got_df.sort_values(by=["text"]).reset_index(drop=True)
got_counts = got_df[results_pred_column].value_counts().to_dict()
Expand All @@ -172,6 +189,7 @@ def main():
args.expected_results_file_path,
args.results_pred_column,
args.expected_pred_column,
args.input_meta,
)
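
As the hunks above show, a string supplied via --input-meta is parsed with ast.literal_eval and forwarded to pandas as a dtype mapping. A standalone sketch of that flow, assuming a hypothetical results file and field names:

import ast

import pandas as pd

# Value as it would arrive from the command line: a Python dict literal, so
# single-quoted keys are accepted (unlike strict JSON).
input_meta = "{'id': 'str', 'pred': 'str', 'text': 'str'}"

dtype = ast.literal_eval(input_meta) if input_meta else None
df = pd.read_json("expected_results.jsonl", lines=True, dtype=dtype)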


8 changes: 6 additions & 2 deletions nemo_curator/download/doc_builder.py
@@ -15,7 +15,7 @@
import importlib
import os
from abc import ABC, abstractmethod
from typing import List, Tuple
from typing import Dict, List, Tuple

import dask.dataframe as dd
import pandas as pd
@@ -111,6 +111,7 @@ def _download_and_extract_single_partition(
output_type: str,
keep_raw_download: bool,
force_download: bool,
input_meta: str = None,
) -> pd.DataFrame:
url, output_path = paths

@@ -158,6 +159,7 @@ def download_and_extract(
output_type: str = "jsonl",
keep_raw_download=False,
force_download=False,
input_meta: str = None,
) -> DocumentDataset:
"""
Downloads and extracts a dataset into a format accepted by the NeMo Curator
@@ -174,6 +176,7 @@
keep_raw_download: Whether to keep the pre-extracted download file.
force_download: If False, will skip processing all files in output_paths that already exist and
directly read from them instead.
input_meta: A dictionary with the json object field names and data types.
Returns:
A DocumentDataset of the downloaded data
@@ -190,8 +193,9 @@
extractor=extractor,
output_type=output_type,
keep_raw_download=keep_raw_download,
force_download=force_download,
enforce_metadata=False,
input_meta=input_meta,
meta=output_format,
)
