|  | 
|  | 1 | +# coding=utf-8 | 
|  | 2 | +# Copyright 2018 The Tensor2Tensor Authors. | 
|  | 3 | +# | 
|  | 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 5 | +# you may not use this file except in compliance with the License. | 
|  | 6 | +# You may obtain a copy of the License at | 
|  | 7 | +# | 
|  | 8 | +#     http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 9 | +# | 
|  | 10 | +# Unless required by applicable law or agreed to in writing, software | 
|  | 11 | +# distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 13 | +# See the License for the specific language governing permissions and | 
|  | 14 | +# limitations under the License. | 
|  | 15 | + | 
|  | 16 | +"""Data generators for CoNLL dataset.""" | 
|  | 17 | + | 
|  | 18 | +from __future__ import absolute_import | 
|  | 19 | +from __future__ import division | 
|  | 20 | +from __future__ import print_function | 
|  | 21 | + | 
|  | 22 | +import os | 
|  | 23 | +import zipfile | 
|  | 24 | + | 
|  | 25 | +from tensor2tensor.data_generators import generator_utils | 
|  | 26 | +from tensor2tensor.data_generators import problem | 
|  | 27 | +from tensor2tensor.data_generators import text_problems | 
|  | 28 | +from tensor2tensor.utils import registry | 
|  | 29 | +import tensorflow as tf | 
|  | 30 | + | 
|  | 31 | + | 
@registry.register_problem
class Conll2002Ner(text_problems.Text2textTmpdir):
  """Base class for CoNLL2002 problems.

  Subclasses choose which corpus files to read by overriding
  `source_data_files`. Every generated sample is a dict whose "inputs" value
  is the space-joined first column of a sentence and whose "targets" value is
  the space-joined third column (presumably the entity tag; confirm against
  the corpus format).
  """

  def source_data_files(self, dataset_split):
    """Files to be passed to generate_samples.

    Args:
      dataset_split: the split being generated, forwarded unchanged from
        `generate_samples`.

    Returns:
      An iterable of file names, relative to the extracted corpus directory.

    Raises:
      NotImplementedError: always; concrete subclasses must override this.
    """
    raise NotImplementedError()

  def generate_samples(self, data_dir, tmp_dir, dataset_split):
    """Downloads the corpus if needed and yields one sample per sentence.

    Sentences are delimited by blank lines. Each non-blank line is split on
    whitespace; column 0 is collected as the word and column 2 as its tag.

    Args:
      data_dir: unused.
      tmp_dir: directory where the zip archive is downloaded and extracted.
      dataset_split: split identifier handed to `source_data_files`.

    Yields:
      Dicts with string values under the keys "inputs" and "targets".
    """
    del data_dir  # Unused.

    url = "https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/corpora/conll2002.zip"  # pylint: disable=line-too-long
    compressed_filename = os.path.basename(url)
    compressed_filepath = os.path.join(tmp_dir, compressed_filename)
    generator_utils.maybe_download(tmp_dir, compressed_filename, url)

    # Drop the ".zip" extension to get the extraction directory. Do not use
    # str.strip(".zip") here: it removes any of the characters '.', 'z', 'i',
    # 'p' from BOTH ends of the path, which corrupts paths such as "./tmp/..."
    # or "zip_cache/...".
    compressed_dir = os.path.splitext(compressed_filepath)[0]

    filenames = self.source_data_files(dataset_split)
    for filename in filenames:
      filepath = os.path.join(compressed_dir, filename)
      if not tf.gfile.Exists(filepath):
        with zipfile.ZipFile(compressed_filepath, "r") as corpus_zip:
          corpus_zip.extractall(tmp_dir)
      with tf.gfile.GFile(filepath, mode="r") as cur_file:
        words, tags = [], []
        for line in cur_file:
          line_split = line.strip().split()
          if not line_split:
            # A blank line ends the current sentence. Skip the flush when
            # nothing was collected (leading or consecutive blank lines) so
            # that no empty samples are emitted.
            if words:
              yield {"inputs": " ".join(words), "targets": " ".join(tags)}
              words, tags = [], []
            continue
          words.append(line_split[0])
          tags.append(line_split[2])
        # Flush the last sentence when the file does not end in a blank line.
        if words:
          yield {"inputs": " ".join(words), "targets": " ".join(tags)}
|  | 71 | + | 
|  | 72 | + | 
@registry.register_problem
class Conll2002EsNer(Conll2002Ner):
  """CoNLL2002 named entity problem for the Spanish corpus files."""

  # File names handed to the base class, relative to the extracted archive.
  TRAIN_FILES = ["esp.train"]
  EVAL_FILES = ["esp.testa", "esp.testb"]

  def source_data_files(self, dataset_split):
    """Returns TRAIN_FILES for the training split, EVAL_FILES otherwise."""
    if dataset_split == problem.DatasetSplit.TRAIN:
      return self.TRAIN_FILES
    return self.EVAL_FILES
|  | 82 | + | 
|  | 83 | + | 
@registry.register_problem
class Conll2002NlNer(Conll2002Ner):
  """CoNLL2002 named entity problem for the Dutch corpus files."""

  # File names handed to the base class, relative to the extracted archive.
  TRAIN_FILES = ["ned.train"]
  EVAL_FILES = ["ned.testa", "ned.testb"]

  def source_data_files(self, dataset_split):
    """Returns TRAIN_FILES for the training split, EVAL_FILES otherwise."""
    if dataset_split == problem.DatasetSplit.TRAIN:
      return self.TRAIN_FILES
    return self.EVAL_FILES
0 commit comments