Skip to content

Commit be0f651

Browse files
yuwen-yankpe
authored and committed
add problems, conll2002_es_ner and conll2002_nl_ner (tensorflow#1253)
1 parent 6056a63 commit be0f651

File tree

1 file changed

+84
-0
lines changed

1 file changed

+84
-0
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# coding=utf-8
2+
# Copyright 2018 The Tensor2Tensor Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Data generators for CoNLL dataset."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
import os
23+
import zipfile
24+
25+
from tensor2tensor.data_generators import generator_utils
26+
from tensor2tensor.data_generators import problem
27+
from tensor2tensor.data_generators import text_problems
28+
from tensor2tensor.utils import registry
29+
import tensorflow as tf
30+
31+
@registry.register_problem
class Conll2002Ner(text_problems.Text2textTmpdir):
  """Base class for CoNLL2002 named-entity recognition problems.

  Subclasses select the language-specific corpus files via
  `source_data_files`. Each generated sample pairs a space-joined
  sentence ("inputs") with its space-joined NER tag sequence
  ("targets").
  """

  def source_data_files(self, dataset_split):
    """Files to be passed to generate_samples."""
    raise NotImplementedError()

  def generate_samples(self, data_dir, tmp_dir, dataset_split):
    """Yields {"inputs": ..., "targets": ...} dicts for `dataset_split`.

    Downloads and extracts the CoNLL2002 corpus into `tmp_dir` if not
    already present, then parses the whitespace-separated
    token/POS-tag/NER-tag columns, emitting one sample per
    blank-line-delimited sentence.

    Args:
      data_dir: unused (required by the Problem API).
      tmp_dir: directory for the downloaded/extracted corpus.
      dataset_split: a problem.DatasetSplit value.

    Yields:
      Dicts with "inputs" (sentence) and "targets" (tag sequence).
    """
    del data_dir

    url = 'https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/corpora/conll2002.zip'  # pylint: disable=line-too-long
    compressed_filename = os.path.basename(url)
    compressed_filepath = os.path.join(tmp_dir, compressed_filename)
    generator_utils.maybe_download(tmp_dir, compressed_filename, url)

    # Bug fix: the previous `compressed_filepath.strip(".zip")` removed any
    # of the characters {., z, i, p} from BOTH ends of the path (str.strip
    # takes a character set, not a suffix), which can mangle paths such as
    # ".../zip.zip". os.path.splitext drops exactly the extension.
    compressed_dir = os.path.splitext(compressed_filepath)[0]

    filenames = self.source_data_files(dataset_split)
    for filename in filenames:
      filepath = os.path.join(compressed_dir, filename)
      if not tf.gfile.Exists(filepath):
        # Extracting the archive creates the "conll2002" directory that
        # `compressed_dir` points at.
        with zipfile.ZipFile(compressed_filepath, 'r') as corpus_zip:
          corpus_zip.extractall(tmp_dir)
      with tf.gfile.GFile(filepath, mode="r") as cur_file:
        words, tags = [], []
        for line in cur_file:
          line_split = line.strip().split()
          if not line_split:
            # Blank line terminates a sentence. Only yield if we actually
            # accumulated tokens, so consecutive blank lines never emit
            # an empty sample.
            if words:
              yield {"inputs": " ".join(words),
                     "targets": " ".join(tags)}
            words, tags = [], []
            continue
          # CoNLL2002 columns: word, POS tag, NER tag.
          words.append(line_split[0])
          tags.append(line_split[2])
        # Flush a trailing sentence not followed by a blank line.
        if words:
          yield {"inputs": " ".join(words), "targets": " ".join(tags)}
67+
68+
@registry.register_problem
class Conll2002EsNer(Conll2002Ner):
  """Problem spec for CoNLL2002 Spanish named entity task."""

  TRAIN_FILES = ["esp.train"]
  EVAL_FILES = ["esp.testa", "esp.testb"]

  def source_data_files(self, dataset_split):
    """Returns the Spanish corpus files for the requested split."""
    if dataset_split == problem.DatasetSplit.TRAIN:
      return self.TRAIN_FILES
    return self.EVAL_FILES
76+
77+
@registry.register_problem
class Conll2002NlNer(Conll2002Ner):
  """Problem spec for CoNLL2002 Dutch named entity task."""

  TRAIN_FILES = ["ned.train"]
  EVAL_FILES = ["ned.testa", "ned.testb"]

  def source_data_files(self, dataset_split):
    """Returns the Dutch corpus files for the requested split."""
    if dataset_split == problem.DatasetSplit.TRAIN:
      return self.TRAIN_FILES
    return self.EVAL_FILES

0 commit comments

Comments
 (0)