This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[WIP] filter API #741

Closed
wants to merge 4 commits
135 changes: 131 additions & 4 deletions src/gluonnlp/data/dataset.py
@@ -20,7 +20,8 @@
# pylint: disable=undefined-all-variable
"""NLP Toolkit Dataset API. It allows easy and customizable loading of corpora and dataset files.
Files can be loaded into formats that are immediately ready for training and evaluation."""
__all__ = ['TextLineDataset', 'CorpusDataset', 'ConcatDataset', 'TSVDataset', 'NumpyDataset']
__all__ = ['TextLineDataset', 'CorpusDataset', 'ConcatDataset', 'TSVDataset', 'NumpyDataset',
           'Filter', 'RangeFilter', 'SplitFilter']

import io
import os
@@ -52,6 +53,126 @@ def __getitem__(self, i):
    def __len__(self):
        return self.cum_sizes[-1]

class Filter(object):
    """Abstract Filter class for Dataset creation.

    A Filter can be used at Dataset construction time to reduce the number
    of samples held in memory (and passed to subsequent transformations).
    """
    def __call__(self, index, data):
        """Check whether the data sample passes the filter.

        Parameters
        ----------
        index : int
            The original dataset index before filtering is applied.
        data : object
            The original data sample at the provided index.

        Returns
        -------
        bool
            True if the sample passes the filter and should be kept.
        """
        raise NotImplementedError()

class RangeFilter(Filter):
    """RangeFilter filters data samples based on their index in the dataset.
    Only samples whose index falls within the range [start_idx, end_idx)
    pass the filter.

    Parameters
    ----------
    start_idx : int
        The start index (inclusive).
    end_idx : int or None
        The end index (exclusive). If None, the range extends to the end
        of the dataset.

    Example
    -------
    >>> data = "a,b,c\n"
    >>> data += "d,e,f\n"
    >>> data += "g,h,i\n"
    >>> data += "j,k,l\n"
    >>> data += "m,n,o\n"
    >>> with open('test_range_filter.txt', 'w') as fout:
    ...     fout.write(data)
    >>> # keep samples with indices in the range [1, 3)
    >>> filter_fn = nlp.data.RangeFilter(1, 3)
    >>> dataset = nlp.data.TextLineDataset('test_range_filter.txt', filter_fn=filter_fn)
    >>> len(dataset)
    2
    >>> dataset[0]
    'd,e,f'
    >>> dataset[1]
    'g,h,i'
    """
    def __init__(self, start_idx, end_idx):
        self.start = start_idx
        self.end = end_idx
        if end_idx is not None:
            assert self.start < self.end, 'end_idx must be greater than start_idx'

    def __call__(self, index, data):
        """Check whether the data sample passes the filter.

        Parameters
        ----------
        index : int
            The original dataset index before filtering is applied.
        data : object
            The original data sample at the provided index.
        """
        if self.end is not None:
            return self.start <= index < self.end
        return index >= self.start

class SplitFilter(Filter):
    """SplitFilter filters data samples based on the number of partitions
    of the dataset and a target partition index. Only data samples belonging
    to the target partition pass the filter.

    Parameters
    ----------
    num_parts : int
        The number of partitions.
    part_idx : int
        The target partition index that will pass the filter.

    Example
    -------
    >>> data = "a,b,c\n"
    >>> data += "d,e,f\n"
    >>> data += "g,h,i\n"
    >>> data += "j,k,l\n"
    >>> data += "m,n,o\n"
    >>> with open('test_split_filter.txt', 'w') as fout:
    ...     fout.write(data)
    >>> # create 2 partitions, and read partition 0 only
    >>> filter_fn = nlp.data.SplitFilter(2, 0)
    >>> dataset = nlp.data.TextLineDataset('test_split_filter.txt', filter_fn=filter_fn)
    >>> len(dataset)
    3
    >>> dataset[0]
    'a,b,c'
    >>> dataset[1]
    'g,h,i'
    >>> dataset[2]
    'm,n,o'
    """
    def __init__(self, num_parts, part_idx):
        self.num_parts = num_parts
        self.part_idx = part_idx
        assert self.part_idx < self.num_parts, 'part_idx should be less than num_parts'

    def __call__(self, index, data):
        """Check whether the data sample passes the filter.

        Parameters
        ----------
        index : int
            The original dataset index before filtering is applied.
        data : object
            The original data sample at the provided index.
        """
        # samples are assigned to partitions round-robin by index
        return index % self.num_parts == self.part_idx

class TextLineDataset(SimpleDataset):
    """Dataset that comprises lines in a file. Each line will be stripped.
@@ -62,12 +183,18 @@ class TextLineDataset(SimpleDataset):
        Path to the input text file.
    encoding : str, default 'utf8'
        File encoding format.
    filter_fn : Filter, default None
        Filter function to decide if the line should be kept or discarded.
    """
    def __init__(self, filename, encoding='utf8'):
    def __init__(self, filename, encoding='utf8', filter_fn=None):
        lines = []
        with io.open(filename, 'r', encoding=encoding) as in_file:
            for line in in_file:
                lines.append(line.strip())
            for idx, line in enumerate(in_file):
                line = line.strip()
                # check if the line should be filtered out
                if filter_fn is not None and not filter_fn(idx, line):
                    continue
                lines.append(line)
        super(TextLineDataset, self).__init__(lines)

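Since a Filter only needs to implement __call__(index, data) and return a bool, user-defined filters plug into the same filter_fn argument. Below is a minimal sketch, not part of this diff, of a hypothetical LengthFilter that discards lines longer than a given number of characters ('corpus.txt' is a placeholder path):

import gluonnlp as nlp

class LengthFilter(nlp.data.Filter):
    """Keep only lines of at most max_len characters (hypothetical example)."""
    def __init__(self, max_len):
        self.max_len = max_len

    def __call__(self, index, data):
        # data is the stripped line passed in by TextLineDataset
        return len(data) <= self.max_len

dataset = nlp.data.TextLineDataset('corpus.txt', filter_fn=LengthFilter(100))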
44 changes: 44 additions & 0 deletions tests/unittest/test_datasets.py
@@ -655,3 +655,47 @@ def test_numpy_dataset():
    assert np.all(dataset[1][1] == b[1])
    dataset_b = dataset.get_field('b')
    assert np.all(dataset_b == b)

def test_split_filter():
    data = "a,b,c\n"
    data += "d,e,f\n"
    data += "g,h,i\n"
    data += "j,k,l\n"
    data += "m,n,o\n"
    with open('test_split_filter.txt', 'w') as fout:
        fout.write(data)
    filter_fn = nlp.data.SplitFilter(2, 0)
    dataset = nlp.data.TextLineDataset('test_split_filter.txt', filter_fn=filter_fn)
    assert len(dataset) == 3
    assert dataset[0] == "a,b,c"
    assert dataset[1] == "g,h,i"
    assert dataset[2] == "m,n,o"

    filter_fn = nlp.data.SplitFilter(2, 1)
    dataset = nlp.data.TextLineDataset('test_split_filter.txt', filter_fn=filter_fn)
    assert len(dataset) == 2
    assert dataset[0] == "d,e,f"
    assert dataset[1] == "j,k,l"

def test_range_filter():
    data = "a,b,c\n"
    data += "d,e,f\n"
    data += "g,h,i\n"
    data += "j,k,l\n"
    data += "m,n,o\n"
    with open('test_range_filter.txt', 'w') as fout:
        fout.write(data)
    filter_fn = nlp.data.RangeFilter(1, 3)
    dataset = nlp.data.TextLineDataset('test_range_filter.txt', filter_fn=filter_fn)
    assert len(dataset) == 2
    assert dataset[0] == "d,e,f"
    assert dataset[1] == "g,h,i"

    filter_fn = nlp.data.RangeFilter(3, None)
    dataset = nlp.data.TextLineDataset('test_range_filter.txt', filter_fn=filter_fn)
    assert len(dataset) == 2, len(dataset)
    assert dataset[0] == "j,k,l"
    assert dataset[1] == "m,n,o"

if __name__ == '__main__':
    test_split_filter()
    test_range_filter()
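A natural use for SplitFilter is sharded data loading, where each worker in a distributed job materializes only its own partition of the corpus. A minimal sketch, assuming the API in this diff; num_workers, worker_id, and 'corpus.txt' are hypothetical placeholders:

import gluonnlp as nlp

num_workers = 4  # total number of partitions (hypothetical)
worker_id = 0    # this worker's partition index (hypothetical)

# each worker keeps only lines with index % num_workers == worker_id
filter_fn = nlp.data.SplitFilter(num_parts=num_workers, part_idx=worker_id)
shard = nlp.data.TextLineDataset('corpus.txt', filter_fn=filter_fn)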