diff --git a/src/gluonnlp/data/dataset.py b/src/gluonnlp/data/dataset.py
index 8b279dc838..6a7a2735fa 100644
--- a/src/gluonnlp/data/dataset.py
+++ b/src/gluonnlp/data/dataset.py
@@ -20,7 +20,8 @@
 # pylint: disable=undefined-all-variable
 """NLP Toolkit Dataset API. It allows easy and customizable loading of corpora and dataset files.
 Files can be loaded into formats that are immediately ready for training and evaluation."""
-__all__ = ['TextLineDataset', 'CorpusDataset', 'ConcatDataset', 'TSVDataset', 'NumpyDataset']
+__all__ = ['TextLineDataset', 'CorpusDataset', 'ConcatDataset', 'TSVDataset', 'NumpyDataset',
+           'Filter', 'RangeFilter', 'SplitFilter']
 
 import io
 import os
@@ -52,6 +53,126 @@ def __getitem__(self, i):
     def __len__(self):
         return self.cum_sizes[-1]
 
+class Filter(object):
+    """Abstract Filter class for Dataset creation.
+
+    A Filter can be used at Dataset construction time to reduce the number of
+    samples held in memory (and passed through subsequent transformations).
+    """
+    def __call__(self, index, data):
+        """Check whether the data sample passes the filter.
+
+        Parameters
+        ----------
+        index : int
+            The original dataset index before filtering is applied.
+        data : object
+            The original data sample object at the provided index.
+        """
+        raise NotImplementedError()
+
+class RangeFilter(Filter):
+    r"""RangeFilter filters the data samples based on the range [start_idx, end_idx)
+    from the dataset. Only data samples within the range pass the filter.
+
+    Parameters
+    ----------
+    start_idx : int
+        The start index (included).
+    end_idx : int or None
+        The end index (excluded). If set to None, it is treated as infinity.
+
+    Example
+    -------
+    >>> data = "a,b,c\n"
+    >>> data += "d,e,f\n"
+    >>> data += "g,h,i\n"
+    >>> data += "j,k,l\n"
+    >>> data += "m,n,o\n"
+    >>> with open('test_range_filter.txt', 'w') as fout:
+    ...     _ = fout.write(data)
+    >>>
+    >>> # keep only lines whose index falls in the range [1, 3)
+    >>> filter_fn = nlp.data.RangeFilter(1, 3)
+    >>> dataset = nlp.data.TextLineDataset('test_range_filter.txt', filter_fn=filter_fn)
+    >>> len(dataset)
+    2
+    >>> dataset[0]
+    'd,e,f'
+    >>> dataset[1]
+    'g,h,i'
+    """
+    def __init__(self, start_idx, end_idx):
+        self.start = start_idx
+        self.end = end_idx
+        if end_idx is not None:
+            assert self.start < self.end, 'end_idx must be greater than start_idx'
+
+    def __call__(self, index, data):
+        """Check whether the data sample passes the filter.
+
+        Parameters
+        ----------
+        index : int
+            The original dataset index before filtering is applied.
+        data : object
+            The original data sample object at the provided index.
+        """
+        if self.end is not None:
+            return self.start <= index < self.end
+        else:
+            return index >= self.start
+
+class SplitFilter(Filter):
+    r"""SplitFilter filters the data samples based on the number of partitions
+    and the partition index of the dataset. Only data samples for the target
+    partition index pass the filter.
+
+    Parameters
+    ----------
+    num_parts : int
+        The number of partitions.
+    part_idx : int
+        The target partition index that will pass the filter.
+
+    Example
+    -------
+    >>> data = "a,b,c\n"
+    >>> data += "d,e,f\n"
+    >>> data += "g,h,i\n"
+    >>> data += "j,k,l\n"
+    >>> data += "m,n,o\n"
+    >>> with open('test_split_filter.txt', 'w') as fout:
+    ...     _ = fout.write(data)
+    >>>
+    >>> # create 2 partitions, and read partition 0 only
+    >>> filter_fn = nlp.data.SplitFilter(2, 0)
+    >>> dataset = nlp.data.TextLineDataset('test_split_filter.txt', filter_fn=filter_fn)
+    >>> len(dataset)
+    3
+    >>> dataset[0]
+    'a,b,c'
+    >>> dataset[1]
+    'g,h,i'
+    >>> dataset[2]
+    'm,n,o'
+    """
+    def __init__(self, num_parts, part_idx):
+        self.num_parts = num_parts
+        self.part_idx = part_idx
+        assert self.part_idx < self.num_parts, 'part_idx should be less than num_parts'
+
+    def __call__(self, index, data):
+        """Check whether the data sample passes the filter.
+
+        Parameters
+        ----------
+        index : int
+            The original dataset index before filtering is applied.
+        data : object
+            The original data sample object at the provided index.
+        """
+        return index % self.num_parts == self.part_idx
 
 class TextLineDataset(SimpleDataset):
     """Dataset that comprises lines in a file. Each line will be stripped.
@@ -62,12 +183,18 @@ class TextLineDataset(SimpleDataset):
         Path to the input text file.
     encoding : str, default 'utf8'
         File encoding format.
+    filter_fn : Filter, default None
+        Filter function to decide whether a line should be kept or discarded.
     """
-    def __init__(self, filename, encoding='utf8'):
+    def __init__(self, filename, encoding='utf8', filter_fn=None):
         lines = []
         with io.open(filename, 'r', encoding=encoding) as in_file:
-            for line in in_file:
-                lines.append(line.strip())
+            for idx, line in enumerate(in_file):
+                line = line.strip()
+                # skip the line if it does not pass the filter
+                if filter_fn is not None and not filter_fn(idx, line):
+                    continue
+                lines.append(line)
         super(TextLineDataset, self).__init__(lines)
 
 
diff --git a/tests/unittest/test_datasets.py b/tests/unittest/test_datasets.py
index a04542c614..951bc06acc 100644
--- a/tests/unittest/test_datasets.py
+++ b/tests/unittest/test_datasets.py
@@ -655,3 +655,44 @@ def test_numpy_dataset():
     assert np.all(dataset[1][1] == b[1])
     dataset_b = dataset.get_field('b')
     assert np.all(dataset_b == b)
+
+def test_split_filter():
+    data = "a,b,c\n"
+    data += "d,e,f\n"
+    data += "g,h,i\n"
+    data += "j,k,l\n"
+    data += "m,n,o\n"
+    with open('test_split_filter.txt', 'w') as fout:
+        fout.write(data)
+    filter_fn = nlp.data.SplitFilter(2, 0)
+    dataset = nlp.data.TextLineDataset('test_split_filter.txt', filter_fn=filter_fn)
+    assert len(dataset) == 3
+    assert dataset[0] == "a,b,c"
+    assert dataset[1] == "g,h,i"
+    assert dataset[2] == "m,n,o"
+
+    filter_fn = nlp.data.SplitFilter(2, 1)
+    dataset = nlp.data.TextLineDataset('test_split_filter.txt', filter_fn=filter_fn)
+    assert len(dataset) == 2
+    assert dataset[0] == "d,e,f"
+    assert dataset[1] == "j,k,l"
+
+def test_range_filter():
+    data = "a,b,c\n"
+    data += "d,e,f\n"
+    data += "g,h,i\n"
+    data += "j,k,l\n"
+    data += "m,n,o\n"
+    with open('test_range_filter.txt', 'w') as fout:
+        fout.write(data)
+    filter_fn = nlp.data.RangeFilter(1, 3)
+    dataset = nlp.data.TextLineDataset('test_range_filter.txt', filter_fn=filter_fn)
+    assert len(dataset) == 2
+    assert dataset[0] == "d,e,f"
+    assert dataset[1] == "g,h,i"
+
+    filter_fn = nlp.data.RangeFilter(3, None)
+    dataset = nlp.data.TextLineDataset('test_range_filter.txt', filter_fn=filter_fn)
+    assert len(dataset) == 2, len(dataset)
+    assert dataset[0] == "j,k,l"
+    assert dataset[1] == "m,n,o"
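
Usage note: filter_fn accepts any object implementing __call__(index, data) and returning a bool, so the Filter base class above is straightforward to extend. The following is a minimal sketch of a custom filter that drops blank lines; the EmptyLineFilter name and the file path are illustrative (not part of this patch), and it assumes gluonnlp with this change applied is importable as nlp.

import gluonnlp as nlp

class EmptyLineFilter(nlp.data.Filter):
    """Keep only lines that are non-empty after stripping."""
    def __call__(self, index, data):
        # `index` is the pre-filtering line number; `data` is the stripped line
        return len(data) > 0

with open('test_empty_filter.txt', 'w') as fout:
    fout.write("a,b,c\n\nd,e,f\n")

dataset = nlp.data.TextLineDataset('test_empty_filter.txt', filter_fn=EmptyLineFilter())
assert len(dataset) == 2
assert dataset[0] == 'a,b,c' and dataset[1] == 'd,e,f'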
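
Usage note: a natural application of SplitFilter is sharding a large corpus across training processes so that each worker only materializes its own partition in memory, since index % num_parts == part_idx keeps every num_parts-th line. A sketch under the same assumptions ('corpus.txt' is a placeholder path; num_workers and worker_id stand in for whatever rank/size values the launcher provides):

import gluonnlp as nlp

num_workers = 4  # total number of training processes (illustrative)
worker_id = 0    # this process's partition index, in [0, num_workers)
# each worker keeps every num_workers-th line, starting at offset worker_id
shard = nlp.data.TextLineDataset(
    'corpus.txt', filter_fn=nlp.data.SplitFilter(num_workers, worker_id))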