
Commit dbe7606

ryantwolf authored and nicoleeeluo committed
Improve speed of AddId module (NVIDIA#36)
* Add fast id method (Signed-off-by: Ryan Wolf <[email protected]>)
* Add type conversion (Signed-off-by: Ryan Wolf <[email protected]>)
* Fix off by one errors in tests (Signed-off-by: Ryan Wolf <[email protected]>)

---------

Signed-off-by: Ryan Wolf <[email protected]>
Signed-off-by: Nicole Luo <[email protected]>
1 parent f297076 commit dbe7606

File tree

4 files changed: +94 -12 lines changed

nemo_curator/modules/add_id.py

+41-4
@@ -12,22 +12,58 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional
+
 import dask.dataframe as dd
 import numpy as np
 from dask import delayed
 
 from nemo_curator.datasets import DocumentDataset
+from nemo_curator.utils.module_utils import count_digits
 
 
 class AddId:
-    def __init__(self, id_field, id_prefix="doc_id", start_index=0) -> None:
+    def __init__(
+        self, id_field, id_prefix: str = "doc_id", start_index: Optional[int] = None
+    ) -> None:
         self.id_field = id_field
         self.id_prefix = id_prefix
         self.start_index = start_index
 
     def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
+        if self.start_index is None:
+            return self._add_id_fast(dataset)
+        else:
+            return self._add_id_ordered(dataset)
+
+    def _add_id_fast(self, dataset: DocumentDataset) -> DocumentDataset:
+        meta = dataset.df.dtypes.to_dict()
+        meta[self.id_field] = "string"
+
+        partition_zero_padding = count_digits(dataset.df.npartitions)
+        id_df = dataset.df.map_partitions(
+            self._add_id_fast_partition,
+            partition_zero_padding,
+            meta=meta,
+        )
+
+        return DocumentDataset(id_df)
+
+    def _add_id_fast_partition(self, partition, global_padding, partition_info=None):
+        local_padding = count_digits(len(partition))
+        global_id = partition_info["number"]
+
+        id_column = [
+            f"{self.id_prefix}-{local_id:0{local_padding}d}{global_id:0{global_padding}d}"
+            for local_id in range(len(partition))
+        ]
+        partition[self.id_field] = id_column
+
+        return partition
+
+    def _add_id_ordered(self, dataset: DocumentDataset) -> DocumentDataset:
         original_meta = dataset.df.dtypes.to_dict()
-        original_meta[self.id_field] = "object"
+        original_meta[self.id_field] = "string"
         delayed_dataset = dataset.df.to_delayed()
 
         parition_lengths = [0]
@@ -38,7 +74,7 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
         delayed_id_dataset = []
         for i, partition in enumerate(delayed_dataset):
             delayed_id_dataset.append(
-                delayed(self._add_id_to_partition)(partition, lower_id_bounds[i])
+                delayed(self._add_id_ordered_partition)(partition, lower_id_bounds[i])
             )
 
         id_dataset = DocumentDataset(
@@ -47,11 +83,12 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
 
         return id_dataset
 
-    def _add_id_to_partition(self, partition, partition_start_id):
+    def _add_id_ordered_partition(self, partition, partition_start_id):
        id_column = [
            f"{self.id_prefix}-{int(i + self.start_index):010d}"
            for i in range(partition_start_id, len(partition) + partition_start_id)
        ]
        partition[self.id_field] = id_column
+        partition[self.id_field] = partition[self.id_field].astype("string")
 
         return partition
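
For context, a minimal usage sketch of the two id schemes this change introduces. The toy dataframe is illustrative only, and the exact fast-path id values depend on how Dask splits the rows into partitions:

import dask.dataframe as dd
import pandas as pd

import nemo_curator as nc
from nemo_curator.datasets import DocumentDataset

# Illustrative 5-row dataset spread across 2 partitions.
df = dd.from_pandas(pd.DataFrame({"text": ["a", "b", "c", "d", "e"]}), npartitions=2)
dataset = DocumentDataset(df)

# Default (start_index=None): the fast scheme. Ids are built independently per
# partition via map_partitions, so they are unique but not globally ordered,
# e.g. "doc_id-00", "doc_id-10", ..., "doc_id-01", "doc_id-11".
fast = nc.AddId("id")(dataset)

# start_index supplied: the ordered scheme. Ids are zero-padded to 10 digits and
# increase monotonically across partitions, e.g. "doc_id-0000000000", ...
ordered = nc.AddId("id", start_index=0)(dataset)

print(fast.df["id"].compute().tolist())
print(ordered.df["id"].compute().tolist())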

nemo_curator/scripts/add_id.py

+4-2
@@ -79,8 +79,10 @@ def attach_args(
     parser.add_argument(
         "--starting-index",
         type=int,
-        default=0,
-        help="Starting index from which to start indexing the documents",
+        default=None,
+        help="If supplied, determines the starting index from which to start "
+        "indexing the documents. By default, it is unspecified, and uses an id"
+        " scheme that is fast to calculate and is not guaranteed to be ordered.",
     )
     parser.add_argument(
         "--output-data-dir",

nemo_curator/utils/module_utils.py

+5
@@ -11,7 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
 
 
 def is_batched(function):
     return hasattr(function, "batched") and function.batched
+
+
+def count_digits(num):
+    return math.floor(math.log10(num)) + 1
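
A quick sanity check of count_digits as defined above. Note that it assumes num >= 1, since math.log10(0) raises ValueError; npartitions and non-empty partition lengths satisfy this:

from nemo_curator.utils.module_utils import count_digits

assert count_digits(1) == 1      # single digit
assert count_digits(9) == 1
assert count_digits(10) == 2     # boundary: 10 has two digits
assert count_digits(12345) == 5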

tests/test_add_id.py

+44-6
@@ -16,7 +16,7 @@
 import pandas as pd
 import pytest
 
-import nemo_curator
+import nemo_curator as nc
 from nemo_curator.datasets import DocumentDataset
 
 
@@ -41,10 +41,10 @@ def two_partition_dataset():
     )
 
 
-class TestPrepareTaskData:
+class TestAddId:
     def test_basic_id(self, single_partition_dataset):
         id_field = "id"
-        add_id = nemo_curator.AddId(id_field)
+        add_id = nc.AddId(id_field, start_index=0)
         id_dataset = add_id(single_partition_dataset)
         actual_ids = id_dataset.df[id_field].compute()
         expected_ids = pd.Series(
@@ -63,7 +63,7 @@ def test_basic_id(self, single_partition_dataset):
 
     def test_two_partitions(self, two_partition_dataset):
         id_field = "id"
-        add_id = nemo_curator.AddId(id_field)
+        add_id = nc.AddId(id_field, start_index=0)
         id_dataset = add_id(two_partition_dataset)
         actual_ids = id_dataset.df[id_field].compute()
         expected_ids = pd.Series(
@@ -83,7 +83,7 @@ def test_two_partitions(self, two_partition_dataset):
     def test_id_prefix(self, two_partition_dataset):
         id_field = "id"
         id_prefix = "my_id"
-        add_id = nemo_curator.AddId(id_field, id_prefix=id_prefix)
+        add_id = nc.AddId(id_field, id_prefix=id_prefix, start_index=0)
         id_dataset = add_id(two_partition_dataset)
         actual_ids = id_dataset.df[id_field].compute()
         expected_ids = pd.Series(
@@ -103,7 +103,7 @@ def test_id_prefix(self, two_partition_dataset):
     def test_start_index(self, two_partition_dataset):
         id_field = "id"
         start_index = 13
-        add_id = nemo_curator.AddId(id_field, start_index=start_index)
+        add_id = nc.AddId(id_field, start_index=start_index)
         id_dataset = add_id(two_partition_dataset)
         actual_ids = id_dataset.df[id_field].compute()
         expected_ids = pd.Series(
@@ -119,3 +119,41 @@ def test_start_index(self, two_partition_dataset):
         assert all(
             expected_ids == actual_ids
         ), f"Expected: {expected_ids}, got: {actual_ids}"
+
+    def test_fast_id_single_partition(self, single_partition_dataset):
+        id_field = "id"
+        add_id = nc.AddId(id_field)
+        id_dataset = add_id(single_partition_dataset)
+        actual_ids = id_dataset.df[id_field].compute()
+        expected_ids = pd.Series(
+            [
+                "doc_id-00",
+                "doc_id-10",
+                "doc_id-20",
+                "doc_id-30",
+                "doc_id-40",
+            ]
+        )
+
+        assert all(
+            expected_ids == actual_ids
+        ), f"Expected: {expected_ids}, got: {actual_ids}"
+
+    def test_fast_id_two_partitions(self, two_partition_dataset):
+        id_field = "id"
+        add_id = nc.AddId(id_field)
+        id_dataset = add_id(two_partition_dataset)
+        actual_ids = id_dataset.df[id_field].compute()
+        expected_ids = pd.Series(
+            [
+                "doc_id-00",
+                "doc_id-10",
+                "doc_id-20",
+                "doc_id-01",
+                "doc_id-11",
+            ]
+        )
+
+        assert all(
+            expected_ids == actual_ids
+        ), f"Expected: {expected_ids}, got: {actual_ids}"
