diff --git a/.github/workflows/issue-manager.yml b/.github/workflows/issue-manager.yml new file mode 100644 index 00000000..3fb42ed0 --- /dev/null +++ b/.github/workflows/issue-manager.yml @@ -0,0 +1,29 @@ +name: Issue Manager + +on: + schedule: + - cron: "0 0 * * *" + issue_comment: + types: + - created + - edited + issues: + types: + - labeled + +jobs: + issue-manager: + runs-on: ubuntu-latest + steps: + - uses: tiangolo/issue-manager@0.2.1 + with: + token: ${{ secrets.GITHUB_TOKEN }} + config: > + { + "resolved": { + "delay": "P7D", + "message": "This issue has been automatically closed because it was answered and there was no follow-up discussion.", + "remove_label_on_comment": true, + "remove_label_on_close": true + } + } diff --git a/HISTORY.rst b/HISTORY.rst index cdc6463c..a0a8d647 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -2,6 +2,12 @@ History ================================================================================ +Current +-------------------------------------------------------------------------------- + +* Try to use Github Actions (GH-353) +* Copy dependece_parser module from supar (GH-157) + 1.2.3 (2020-11-28) -------------------------------------------------------------------------------- diff --git a/README.rst b/README.rst index 5de272b3..47d34482 100644 --- a/README.rst +++ b/README.rst @@ -29,6 +29,8 @@ Underthesea - Vietnamese NLP Toolkit **underthesea** is a suite of open source Python modules, data sets and tutorials supporting research and development in Vietnamese Natural Language Processing. +💫 **Version 1.3.0a0 out now!** `Underthesea meet deep learning! `_ + +-----------------+------------------------------------------------------------------------------------------------+ | Free software | GNU General Public License v3 | +-----------------+------------------------------------------------------------------------------------------------+ diff --git a/setup.py b/setup.py index 21d2f52c..b164312c 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,9 @@ 'scikit-learn>=0.20,<0.22', 'unidecode', 'seqeval', - 'PyYAML' + 'PyYAML', + 'torch>=1.1.0,<=1.5.1', + 'transformers>=3.5.0,<=3.5.1' ] tests_require = [ diff --git a/underthesea/VERSION b/underthesea/VERSION index 0495c4a8..a3f53673 100644 --- a/underthesea/VERSION +++ b/underthesea/VERSION @@ -1 +1 @@ -1.2.3 +1.3.0-alpha diff --git a/underthesea/__init__.py b/underthesea/__init__.py index 51fb3ff8..cf1aec5a 100644 --- a/underthesea/__init__.py +++ b/underthesea/__init__.py @@ -45,8 +45,43 @@ except Exception: pass +########################################################### +# Initialize +########################################################### +import torch +import logging.config + +# global variable: device for torch +device = None +if torch.cuda.is_available(): + device = torch.device("cuda:0") +else: + device = torch.device("cpu") + +logging.config.dictConfig( + { + "version": 1, + "disable_existing_loggers": False, + "formatters": {"standard": {"format": "%(asctime)-15s %(message)s"}}, + "handlers": { + "console": { + "level": "INFO", + "class": "logging.StreamHandler", + "formatter": "standard", + "stream": "ext://sys.stdout", + } + }, + "loggers": { + "underthesea": {"handlers": ["console"], "level": "INFO", "propagate": False} + }, + } +) + +logger = logging.getLogger("underthesea") + __all__ = [ 'sent_tokenize', 'word_tokenize', 'pos_tag', 'chunk', 'ner', - 'classify', 'sentiment' + 'classify', 'sentiment', + 'logger', 'device' ] diff --git a/underthesea/data.py b/underthesea/data.py new file 
mode 100644 index 00000000..1bfc3744 --- /dev/null +++ b/underthesea/data.py @@ -0,0 +1,695 @@ +# -*- coding: utf-8 -*- +from collections.abc import Iterable +import nltk +from tqdm import tqdm +from underthesea.utils.sp_parallel import is_master + + +def progress_bar(iterator, + ncols=None, + bar_format='{l_bar}{bar:36}| {n_fmt}/{total_fmt} {elapsed}<{remaining}, {rate_fmt}{postfix}', + leave=True): + return tqdm(iterator, + ncols=ncols, + bar_format=bar_format, + ascii=True, + disable=(not is_master()), + leave=leave) + + +class Transform(object): + r""" + A Transform object corresponds to a specific data format. + It holds several instances of data fields that provide instructions for preprocessing and numericalizing, etc. + + Attributes: + training (bool): + Sets the object in training mode. + If ``False``, some data fields not required for predictions won't be returned. + Default: ``True``. + """ + + fields = [] + + def __init__(self): + self.training = True + + def __call__(self, sentences): + pairs = dict() + for field in self: + if field not in self.src and field not in self.tgt: + continue + if not self.training and field in self.tgt: + continue + if not isinstance(field, Iterable): + field = [field] + for f in field: + if f is not None: + pairs[f] = f.transform([getattr(i, f.name) for i in sentences]) + + return pairs + + def __getitem__(self, index): + return getattr(self, self.fields[index]) + + def train(self, training=True): + self.training = training + + def eval(self): + self.train(False) + + def append(self, field): + self.fields.append(field.name) + setattr(self, field.name, field) + + @property + def src(self): + raise AttributeError + + @property + def tgt(self): + raise AttributeError + + def save(self, path, sentences): + with open(path, 'w') as f: + f.write('\n'.join([str(i) for i in sentences]) + '\n') + + +class Sentence(object): + r""" + A Sentence object holds a sentence with regard to specific data format. + """ + + def __init__(self, transform): + self.transform = transform + + # mapping from each nested field to their proper position + self.maps = dict() + # names of each field + self.keys = set() + # values of each position + self.values = [] + for i, field in enumerate(self.transform): + if not isinstance(field, Iterable): + field = [field] + for f in field: + if f is not None: + self.maps[f.name] = i + self.keys.add(f.name) + + def __len__(self): + return len(self.values[0]) + + def __contains__(self, key): + return key in self.keys + + def __getattr__(self, name): + if name in self.__dict__: + return self.__dict__[name] + else: + return self.values[self.maps[name]] + + def __setattr__(self, name, value): + if 'keys' in self.__dict__ and name in self: + index = self.maps[name] + if index >= len(self.values): + self.__dict__[name] = value + else: + self.values[index] = value + else: + self.__dict__[name] = value + + def __getstate__(self): + return vars(self) + + def __setstate__(self, state): + self.__dict__.update(state) + + +class CoNLL(Transform): + r""" + The CoNLL object holds ten fields required for CoNLL-X data format. + Each field can be binded with one or more :class:`Field` objects. For example, + ``FORM`` can contain both :class:`Field` and :class:`SubwordField` to produce tensors for words and subwords. + + Attributes: + ID: + Token counter, starting at 1. + FORM: + Words in the sentence. + LEMMA: + Lemmas or stems (depending on the particular treebank) of words, or underscores if not available. 
+ CPOS: + Coarse-grained part-of-speech tags, where the tagset depends on the treebank. + POS: + Fine-grained part-of-speech tags, where the tagset depends on the treebank. + FEATS: + Unordered set of syntactic and/or morphological features (depending on the particular treebank), + or underscores if not available. + HEAD: + Heads of the tokens, which are either values of ID or zeros. + DEPREL: + Dependency relations to the HEAD. + PHEAD: + Projective heads of tokens, which are either values of ID or zeros, or underscores if not available. + PDEPREL: + Dependency relations to the PHEAD, or underscores if not available. + + References: + - Sabine Buchholz and Erwin Marsi. 2006. + `CoNLL-X Shared Task on Multilingual Dependency Parsing`_. + + .. _CoNLL-X Shared Task on Multilingual Dependency Parsing: + https://www.aclweb.org/anthology/W06-2920/ + """ + + fields = ['ID', 'FORM', 'LEMMA', 'CPOS', 'POS', 'FEATS', 'HEAD', 'DEPREL', 'PHEAD', 'PDEPREL'] + + def __init__(self, + ID=None, FORM=None, LEMMA=None, CPOS=None, POS=None, + FEATS=None, HEAD=None, DEPREL=None, PHEAD=None, PDEPREL=None): + super().__init__() + + self.ID = ID + self.FORM = FORM + self.LEMMA = LEMMA + self.CPOS = CPOS + self.POS = POS + self.FEATS = FEATS + self.HEAD = HEAD + self.DEPREL = DEPREL + self.PHEAD = PHEAD + self.PDEPREL = PDEPREL + + @property + def src(self): + return self.FORM, self.CPOS + + @property + def tgt(self): + return self.HEAD, self.DEPREL + + @classmethod + def get_arcs(cls, sequence): + return [int(i) for i in sequence] + + @classmethod + def get_sibs(cls, sequence): + sibs = [-1] * (len(sequence) + 1) + heads = [0] + [int(i) for i in sequence] + + for i in range(1, len(heads)): + hi = heads[i] + for j in range(i + 1, len(heads)): + hj = heads[j] + di, dj = hi - i, hj - j + if hi >= 0 and hj >= 0 and hi == hj and di * dj > 0: + if abs(di) > abs(dj): + sibs[i] = j + else: + sibs[j] = i + break + return sibs[1:] + + @classmethod + def toconll(cls, tokens): + r""" + Converts a list of tokens to a string in CoNLL-X format. + Missing fields are filled with underscores. + + Args: + tokens (list[str] or list[tuple]): + This can be either a list of words or word/pos pairs. + + Returns: + A string in CoNLL-X format. + + Examples: + >>> print(CoNLL.toconll(['She', 'enjoys', 'playing', 'tennis', '.'])) + 1 She _ _ _ _ _ _ _ _ + 2 enjoys _ _ _ _ _ _ _ _ + 3 playing _ _ _ _ _ _ _ _ + 4 tennis _ _ _ _ _ _ _ _ + 5 . _ _ _ _ _ _ _ _ + + """ + + if isinstance(tokens[0], str): + s = '\n'.join([f"{i}\t{word}\t" + '\t'.join(['_'] * 8) + for i, word in enumerate(tokens, 1)]) + else: + s = '\n'.join([f"{i}\t{word}\t_\t{tag}\t" + '\t'.join(['_'] * 6) + for i, (word, tag) in enumerate(tokens, 1)]) + return s + '\n' + + @classmethod + def isprojective(cls, sequence): + r""" + Checks if a dependency tree is projective. + This also works for partial annotation. + + Besides the obvious crossing arcs, the examples below illustrate two non-projective cases + which are hard to detect in the scenario of partial annotation. + + Args: + sequence (list[int]): + A list of head indices. + + Returns: + ``True`` if the tree is projective, ``False`` otherwise. 
+ + Examples: + >>> CoNLL.isprojective([2, -1, 1]) # -1 denotes un-annotated cases + False + >>> CoNLL.isprojective([3, -1, 2]) + False + """ + + pairs = [(h, d) for d, h in enumerate(sequence, 1) if h >= 0] + for i, (hi, di) in enumerate(pairs): + for hj, dj in pairs[i + 1:]: + (li, ri), (lj, rj) = sorted([hi, di]), sorted([hj, dj]) + if li <= hj <= ri and hi == dj: + return False + if lj <= hi <= rj and hj == di: + return False + if (li < lj < ri or li < rj < ri) and (li - lj) * (ri - rj) > 0: + return False + return True + + @classmethod + def istree(cls, sequence, proj=False, multiroot=False): + r""" + Checks if the arcs form an valid dependency tree. + + Args: + sequence (list[int]): + A list of head indices. + proj (bool): + If ``True``, requires the tree to be projective. Default: ``False``. + multiroot (bool): + If ``False``, requires the tree to contain only a single root. Default: ``True``. + + Returns: + ``True`` if the arcs form an valid tree, ``False`` otherwise. + + Examples: + >>> CoNLL.istree([3, 0, 0, 3], multiroot=True) + True + >>> CoNLL.istree([3, 0, 0, 3], proj=True) + False + """ + + from underthesea.utils.sp_alg import tarjan + if proj and not cls.isprojective(sequence): + return False + n_roots = sum(head == 0 for head in sequence) + if n_roots == 0: + return False + if not multiroot and n_roots > 1: + return False + if any(i == head for i, head in enumerate(sequence, 1)): + return False + return next(tarjan(sequence), None) is None + + def load(self, data, proj=False, max_len=None, **kwargs): + r""" + Loads the data in CoNLL-X format. + Also supports for loading data from CoNLL-U file with comments and non-integer IDs. + + Args: + data (list[list] or str): + A list of instances or a filename. + proj (bool): + If ``True``, discards all non-projective sentences. Default: ``False``. + max_len (int): + Sentences exceeding the length will be discarded. Default: ``None``. + + Returns: + A list of :class:`CoNLLSentence` instances. + """ + + if isinstance(data, str): + with open(data, 'r') as f: + lines = [line.strip() for line in f] + else: + data = [data] if isinstance(data[0], str) else data + lines = '\n'.join([self.toconll(i) for i in data]).split('\n') + + i, start, sentences = 0, 0, [] + for line in progress_bar(lines, leave=False): + if not line: + sentences.append(CoNLLSentence(self, lines[start:i])) + start = i + 1 + i += 1 + if proj: + sentences = [i for i in sentences if self.isprojective(list(map(int, i.arcs)))] + if max_len is not None: + sentences = [i for i in sentences if len(i) < max_len] + + return sentences + + +class CoNLLSentence(Sentence): + r""" + Sencence in CoNLL-X format. + + Args: + transform (CoNLL): + A :class:`CoNLL` object. + lines (list[str]): + A list of strings composing a sentence in CoNLL-X format. + Comments and non-integer IDs are permitted. + + Examples: + >>> lines = ['# text = But I found the location wonderful and the neighbors very kind.', + '1\tBut\t_\t_\t_\t_\t_\t_\t_\t_', + '2\tI\t_\t_\t_\t_\t_\t_\t_\t_', + '3\tfound\t_\t_\t_\t_\t_\t_\t_\t_', + '4\tthe\t_\t_\t_\t_\t_\t_\t_\t_', + '5\tlocation\t_\t_\t_\t_\t_\t_\t_\t_', + '6\twonderful\t_\t_\t_\t_\t_\t_\t_\t_', + '7\tand\t_\t_\t_\t_\t_\t_\t_\t_', + '7.1\tfound\t_\t_\t_\t_\t_\t_\t_\t_', + '8\tthe\t_\t_\t_\t_\t_\t_\t_\t_', + '9\tneighbors\t_\t_\t_\t_\t_\t_\t_\t_', + '10\tvery\t_\t_\t_\t_\t_\t_\t_\t_', + '11\tkind\t_\t_\t_\t_\t_\t_\t_\t_', + '12\t.\t_\t_\t_\t_\t_\t_\t_\t_'] + >>> sentence = CoNLLSentence(transform, lines) # fields in transform are built from ptb. 
+ >>> sentence.arcs = [3, 3, 0, 5, 6, 3, 6, 9, 11, 11, 6, 3] + >>> sentence.rels = ['cc', 'nsubj', 'root', 'det', 'nsubj', 'xcomp', + 'cc', 'det', 'dep', 'advmod', 'conj', 'punct'] + >>> sentence + # text = But I found the location wonderful and the neighbors very kind. + 1 But _ _ _ _ 3 cc _ _ + 2 I _ _ _ _ 3 nsubj _ _ + 3 found _ _ _ _ 0 root _ _ + 4 the _ _ _ _ 5 det _ _ + 5 location _ _ _ _ 6 nsubj _ _ + 6 wonderful _ _ _ _ 3 xcomp _ _ + 7 and _ _ _ _ 6 cc _ _ + 7.1 found _ _ _ _ _ _ _ _ + 8 the _ _ _ _ 9 det _ _ + 9 neighbors _ _ _ _ 11 dep _ _ + 10 very _ _ _ _ 11 advmod _ _ + 11 kind _ _ _ _ 6 conj _ _ + 12 . _ _ _ _ 3 punct _ _ + """ + + def __init__(self, transform, lines): + super().__init__(transform) + + self.values = [] + # record annotations for post-recovery + self.annotations = dict() + + for i, line in enumerate(lines): + value = line.split('\t') + if value[0].startswith('#') or not value[0].isdigit(): + self.annotations[-i - 1] = line + else: + self.annotations[len(self.values)] = line + self.values.append(value) + self.values = list(zip(*self.values)) + + def __repr__(self): + # cover the raw lines + merged = {**self.annotations, + **{i: '\t'.join(map(str, line)) + for i, line in enumerate(zip(*self.values))}} + return '\n'.join(merged.values()) + '\n' + + +class Tree(Transform): + r""" + The Tree object factorize a constituency tree into four fields, each associated with one or more :class:`Field` objects. + + Attributes: + WORD: + Words in the sentence. + POS: + Part-of-speech tags, or underscores if not available. + TREE: + The raw constituency tree in :class:`nltk.tree.Tree` format. + CHART: + The factorized sequence of binarized tree traversed in pre-order. + """ + + root = '' + fields = ['WORD', 'POS', 'TREE', 'CHART'] + + def __init__(self, WORD=None, POS=None, TREE=None, CHART=None): + super().__init__() + + self.WORD = WORD + self.POS = POS + self.TREE = TREE + self.CHART = CHART + + @property + def src(self): + return self.WORD, self.POS, self.TREE + + @property + def tgt(self): + return self.CHART, + + @classmethod + def totree(cls, tokens, root=''): + r""" + Converts a list of tokens to a :class:`nltk.tree.Tree`. + Missing fields are filled with underscores. + + Args: + tokens (list[str] or list[tuple]): + This can be either a list of words or word/pos pairs. + root (str): + The root label of the tree. Default: ''. + + Returns: + A :class:`nltk.tree.Tree` object. + + Examples: + >>> print(Tree.totree(['She', 'enjoys', 'playing', 'tennis', '.'], 'TOP')) + (TOP (_ She) (_ enjoys) (_ playing) (_ tennis) (_ .)) + """ + + if isinstance(tokens[0], str): + tokens = [(token, '_') for token in tokens] + tree = ' '.join([f"({pos} {word})" for word, pos in tokens]) + return nltk.Tree.fromstring(f"({root} {tree})") + + @classmethod + def binarize(cls, tree): + r""" + Conducts binarization over the tree. + + First, the tree is transformed to satisfy `Chomsky Normal Form (CNF)`_. + Here we call :meth:`~nltk.tree.Tree.chomsky_normal_form` to conduct left-binarization. + Second, all unary productions in the tree are collapsed. + + Args: + tree (nltk.tree.Tree): + The tree to be binarized. + + Returns: + The binarized tree. + + Examples: + >>> tree = nltk.Tree.fromstring(''' + (TOP + (S + (NP (_ She)) + (VP (_ enjoys) (S (VP (_ playing) (NP (_ tennis))))) + (_ .))) + ''') + >>> print(Tree.binarize(tree)) + (TOP + (S + (S|<> + (NP (_ She)) + (VP + (VP|<> (_ enjoys)) + (S+VP (VP|<> (_ playing)) (NP (_ tennis))))) + (S|<> (_ .)))) + + .. 
_Chomsky Normal Form (CNF): + https://en.wikipedia.org/wiki/Chomsky_normal_form + """ + + tree = tree.copy(True) + nodes = [tree] + while nodes: + node = nodes.pop() + if isinstance(node, nltk.Tree): + nodes.extend([child for child in node]) + if len(node) > 1: + for i, child in enumerate(node): + if not isinstance(child[0], nltk.Tree): + node[i] = nltk.Tree(f"{node.label()}|<>", [child]) + tree.chomsky_normal_form('left', 0, 0) + tree.collapse_unary() + + return tree + + @classmethod + def factorize(cls, tree, delete_labels=None, equal_labels=None): + r""" + Factorizes the tree into a sequence. + The tree is traversed in pre-order. + + Args: + tree (nltk.tree.Tree): + The tree to be factorized. + delete_labels (set[str]): + A set of labels to be ignored. This is used for evaluation. + If it is a pre-terminal label, delete the word along with the brackets. + If it is a non-terminal label, just delete the brackets (don't delete childrens). + In `EVALB`_, the default set is: + {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''} + Default: ``None``. + equal_labels (dict[str, str]): + The key-val pairs in the dict are considered equivalent (non-directional). This is used for evaluation. + The default dict defined in `EVALB`_ is: {'ADVP': 'PRT'} + Default: ``None``. + + Returns: + The sequence of the factorized tree. + + Examples: + >>> tree = nltk.Tree.fromstring(''' + (TOP + (S + (NP (_ She)) + (VP (_ enjoys) (S (VP (_ playing) (NP (_ tennis))))) + (_ .))) + ''') + >>> Tree.factorize(tree) + [(0, 5, 'TOP'), (0, 5, 'S'), (0, 1, 'NP'), (1, 4, 'VP'), (2, 4, 'S'), (2, 4, 'VP'), (3, 4, 'NP')] + >>> Tree.factorize(tree, delete_labels={'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''}) + [(0, 5, 'S'), (0, 1, 'NP'), (1, 4, 'VP'), (2, 4, 'S'), (2, 4, 'VP'), (3, 4, 'NP')] + + .. _EVALB: + https://nlp.cs.nyu.edu/evalb/ + """ + + def track(tree, i): + label = tree.label() + if delete_labels is not None and label in delete_labels: + label = None + if equal_labels is not None: + label = equal_labels.get(label, label) + if len(tree) == 1 and not isinstance(tree[0], nltk.Tree): + return (i + 1 if label is not None else i), [] + j, spans = i, [] + for child in tree: + j, s = track(child, j) + spans += s + if label is not None and j > i: + spans = [(i, j, label)] + spans + return j, spans + + return track(tree, 0)[1] + + @classmethod + def build(cls, tree, sequence): + r""" + Builds a constituency tree from the sequence. The sequence is generated in pre-order. + During building the tree, the sequence is de-binarized to the original format (i.e., + the suffixes ``|<>`` are ignored, the collapsed labels are recovered). + + Args: + tree (nltk.tree.Tree): + An empty tree that provides a base for building a result tree. + sequence (list[tuple]): + A list of tuples used for generating a tree. + Each tuple consits of the indices of left/right span boundaries and label of the span. + + Returns: + A result constituency tree. 
+ + Examples: + >>> tree = Tree.totree(['She', 'enjoys', 'playing', 'tennis', '.'], 'TOP') + >>> sequence = [(0, 5, 'S'), (0, 4, 'S|<>'), (0, 1, 'NP'), (1, 4, 'VP'), (1, 2, 'VP|<>'), + (2, 4, 'S+VP'), (2, 3, 'VP|<>'), (3, 4, 'NP'), (4, 5, 'S|<>')] + >>> print(Tree.build(tree, sequence)) + (TOP + (S + (NP (_ She)) + (VP (_ enjoys) (S (VP (_ playing) (NP (_ tennis))))) + (_ .))) + """ + + root = tree.label() + leaves = [subtree for subtree in tree.subtrees() + if not isinstance(subtree[0], nltk.Tree)] + + def track(node): + i, j, label = next(node) + if j == i + 1: + children = [leaves[i]] + else: + children = track(node) + track(node) + if label.endswith('|<>'): + return children + labels = label.split('+') + tree = nltk.Tree(labels[-1], children) + for label in reversed(labels[:-1]): + tree = nltk.Tree(label, [tree]) + return [tree] + + return nltk.Tree(root, track(iter(sequence))) + + def load(self, data, max_len=None, **kwargs): + r""" + Args: + data (list[list] or str): + A list of instances or a filename. + max_len (int): + Sentences exceeding the length will be discarded. Default: ``None``. + + Returns: + A list of :class:`TreeSentence` instances. + """ + if isinstance(data, str): + with open(data, 'r') as f: + trees = [nltk.Tree.fromstring(string) for string in f] + self.root = trees[0].label() + else: + data = [data] if isinstance(data[0], str) else data + trees = [self.totree(i, self.root) for i in data] + + i, sentences = 0, [] + for tree in progress_bar(trees, leave=False): + if len(tree) == 1 and not isinstance(tree[0][0], nltk.Tree): + continue + sentences.append(TreeSentence(self, tree)) + i += 1 + if max_len is not None: + sentences = [i for i in sentences if len(i) < max_len] + + return sentences + + +class TreeSentence(Sentence): + r""" + Args: + transform (Tree): + A :class:`Tree` object. + tree (nltk.tree.Tree): + A :class:`nltk.tree.Tree` object. + """ + + def __init__(self, transform, tree): + super().__init__(transform) + + # the values contain words, pos tags, raw trees, and spans + # the tree is first left-binarized before factorized + # spans are the factorization of tree traversed in pre-order + self.values = [*zip(*tree.pos()), + tree, + Tree.factorize(Tree.binarize(tree)[0])] + + def __repr__(self): + return self.values[-2].pformat(1000000) diff --git a/underthesea/file_utils.py b/underthesea/file_utils.py index 797e84a6..31fa9229 100644 --- a/underthesea/file_utils.py +++ b/underthesea/file_utils.py @@ -1,18 +1,17 @@ """ Utilities for working with the local data set cache. Copied from flair """ +import mmap from os.path import join from pathlib import Path import os -import logging import shutil import tempfile import re from urllib.parse import urlparse from tqdm import tqdm as _tqdm import requests - -logger = logging.getLogger('underthesea') +from underthesea import logger CACHE_ROOT = os.path.expanduser(os.path.join('~', '.underthesea')) DATASETS_FOLDER = join(CACHE_ROOT, "datasets") @@ -127,3 +126,16 @@ def tqdm(*args, **kwargs): } return _tqdm(*args, **new_kwargs) + + +def load_big_file(f: str) -> mmap.mmap: + r""" + Workaround for loading a big pickle file. Files over 2GB cause pickle errors on certin Mac and Windows distributions. 
+ Source: flairNLP + """ + logger.info(f"loading file {f}") + with open(f, "rb") as f_in: + # mmap seems to be much more memory efficient + bf = mmap.mmap(f_in.fileno(), 0, access=mmap.ACCESS_READ) + f_in.close() + return bf diff --git a/underthesea/models/dependency_parser.py b/underthesea/models/dependency_parser.py new file mode 100644 index 00000000..ccbf5e3b --- /dev/null +++ b/underthesea/models/dependency_parser.py @@ -0,0 +1,323 @@ +# -*- coding: utf-8 -*- +import os +from datetime import datetime +from underthesea import logger, device +from underthesea.data import progress_bar +import torch +import torch.nn as nn +from underthesea.models.nn import Model +from underthesea.modules.model import BiaffineDependencyModel +from underthesea.utils.sp_config import Config +from underthesea.utils.sp_data import Dataset +from underthesea.utils.sp_field import Field +from underthesea.utils.sp_fn import ispunct +from underthesea.utils.sp_init import PRETRAINED +from underthesea.utils.sp_metric import AttachmentMetric + + +class DependencyParser(object): + NAME = None + MODEL = None + + def __init__(self, args, model, transform): + self.args = args + self.model = model + self.transform = transform + + def evaluate(self, data, buckets=8, batch_size=5000, **kwargs): + args = self.args.update(locals()) + + self.transform.train() + logger.info("Loading the data") + dataset = Dataset(self.transform, data) + dataset.build(args.batch_size, args.buckets) + logger.info(f"\n{dataset}") + + logger.info("Evaluating the dataset") + start = datetime.now() + loss, metric = self._evaluate(dataset.loader) + elapsed = datetime.now() - start + logger.info(f"loss: {loss:.4f} - {metric}") + logger.info(f"{elapsed}s elapsed, {len(dataset) / elapsed.total_seconds():.2f} Sents/s") + + return loss, metric + + def predict(self, data, pred=None, buckets=8, batch_size=5000, prob=False, **kwargs): + args = self.args.update(locals()) + if not args: + args = locals() + args.update(kwargs) + args = type('Args', (object,), locals()) + + self.transform.eval() + if args.prob: + self.transform.append(Field('probs')) + + logger.info("Loading the data") + dataset = Dataset(self.transform, data) + dataset.build(args.batch_size, args.buckets) + logger.info(f"\n{dataset}") + + logger.info("Making predictions on the dataset") + start = datetime.now() + preds = self._predict(dataset.loader) + elapsed = datetime.now() - start + + for name, value in preds.items(): + setattr(dataset, name, value) + if pred is not None: + logger.info(f"Saving predicted results to {pred}") + self.transform.save(pred, dataset.sentences) + logger.info(f"{elapsed}s elapsed, {len(dataset) / elapsed.total_seconds():.2f} Sents/s") + + return dataset + + def _train(self, loader): + raise NotImplementedError + + @torch.no_grad() + def _evaluate(self, loader): + raise NotImplementedError + + @torch.no_grad() + def _predict(self, loader): + raise NotImplementedError + + @classmethod + def build(cls, path, **kwargs): + raise NotImplementedError + + @classmethod + def load(cls, path, **kwargs): + r""" + Loads a parser with data fields and pretrained model parameters. + + Args: + path (str): + - a string with the shortcut name of a pretrained parser defined in ``supar.PRETRAINED`` + to load from cache or download, e.g., ``'crf-dep-en'``. + - a path to a directory containing a pre-trained parser, e.g., `.//model`. + kwargs (dict): + A dict holding the unconsumed arguments that can be used to update the configurations and initiate the model. 
+ + Examples: + >>> # from supar import Parser + >>> # parser = Parser.load('biaffine-dep-en') + >>> # parser = Parser.load('./ptb.biaffine.dependency.char') + """ + + args = Config(**locals()) + args.device = 'cuda' if torch.cuda.is_available() else 'cpu' + + if os.path.exists(path): + state = torch.load(path) + else: + path = PRETRAINED[path] if path in PRETRAINED else path + state = torch.hub.load_state_dict_from_url(path) + try: + cls = BiaffineDependencyParserSupar + except Exception as e: + print(e) + + state['args'].update(args) + args = state['args'] + model = cls.MODEL(**args) + model.load_pretrained(state['pretrained']) + model.load_state_dict(state['state_dict'], False) + model.to(device) + transform = state['transform'] + return cls(args, model, transform) + + def save(self, path): + model = self.model + if hasattr(model, 'module'): + model = self.model.module + + args = model.args + + state_dict = {k: v.cpu() for k, v in model.state_dict().items()} + pretrained = state_dict.pop('pretrained.weight', None) + state = {'name': self.NAME, + 'args': args, + 'state_dict': state_dict, + 'pretrained': pretrained, + 'transform': self.transform} + torch.save(state, path) + + +class BiaffineDependencyParser(Model): + def __init__(self, embeddings='char', embed=False): + self.embeddings = embeddings + self.embed = embed + + +class BiaffineDependencyParserSupar(DependencyParser): + r""" + The implementation of Biaffine Dependency Parser. + + References: + - Timothy Dozat and Christopher D. Manning. 2017. + `Deep Biaffine Attention for Neural Dependency Parsing`_. + + .. _Deep Biaffine Attention for Neural Dependency Parsing: + https://openreview.net/forum?id=Hk95PK9le + """ + + NAME = 'biaffine-dependency' + MODEL = BiaffineDependencyModel + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + try: + feat = self.args.feat + device = self.args.device + except Exception: + feat = self.args['feat'] + device = self.args['device'] + if feat in ('char', 'bert'): + self.WORD, self.FEAT = self.transform.FORM + else: + self.WORD, self.FEAT = self.transform.FORM, self.transform.CPOS + self.ARC, self.REL = self.transform.HEAD, self.transform.DEPREL + self.puncts = torch.tensor([i + for s, i in self.WORD.vocab.stoi.items() + if ispunct(s)]).to(device) + + def evaluate(self, data, buckets=8, batch_size=5000, + punct=False, tree=True, proj=False, verbose=True, **kwargs): + r""" + Args: + data (str): + The data for evaluation, both list of instances and filename are allowed. + buckets (int): + The number of buckets that sentences are assigned to. Default: 32. + batch_size (int): + The number of tokens in each batch. Default: 5000. + punct (bool): + If ``False``, ignores the punctuations during evaluation. Default: ``False``. + tree (bool): + If ``True``, ensures to output well-formed trees. Default: ``False``. + proj (bool): + If ``True``, ensures to output projective trees. Default: ``False``. + verbose (bool): + If ``True``, increases the output verbosity. Default: ``True``. + kwargs (dict): + A dict holding the unconsumed arguments that can be used to update the configurations for evaluation. + + Returns: + The loss scalar and evaluation results. + """ + + return super().evaluate(**Config().update(locals())) + + def predict(self, data, pred=None, buckets=8, batch_size=5000, + prob=False, tree=True, proj=False, verbose=True, **kwargs): + r""" + Args: + data (list[list] or str): + The data for prediction, both a list of instances and filename are allowed. 
+ pred (str): + If specified, the predicted results will be saved to the file. Default: ``None``. + buckets (int): + The number of buckets that sentences are assigned to. Default: 32. + batch_size (int): + The number of tokens in each batch. Default: 5000. + prob (bool): + If ``True``, outputs the probabilities. Default: ``False``. + tree (bool): + If ``True``, ensures to output well-formed trees. Default: ``False``. + proj (bool): + If ``True``, ensures to output projective trees. Default: ``False``. + verbose (bool): + If ``True``, increases the output verbosity. Default: ``True``. + kwargs (dict): + A dict holding the unconsumed arguments that can be used to update the configurations for prediction. + + Returns: + A :class:`~supar.utils.Dataset` object that stores the predicted results. + """ + + return super().predict(**Config().update(locals())) + + def _train(self, loader): + self.model.train() + + bar, metric = progress_bar(loader), AttachmentMetric() + + for words, feats, arcs, rels in bar: + self.optimizer.zero_grad() + + mask = words.ne(self.WORD.pad_index) + # ignore the first token of each sentence + mask[:, 0] = 0 + s_arc, s_rel = self.model(words, feats) + loss = self.model.loss(s_arc, s_rel, arcs, rels, mask) + loss.backward() + nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip) + self.optimizer.step() + self.scheduler.step() + + arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask) + # ignore all punctuation if not specified + if not self.args.punct: + mask &= words.unsqueeze(-1).ne(self.puncts).all(-1) + metric(arc_preds, rel_preds, arcs, rels, mask) + bar.set_postfix_str(f"lr: {self.scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f} - {metric}") + + @torch.no_grad() + def _evaluate(self, loader): + self.model.eval() + + total_loss, metric = 0, AttachmentMetric() + + tree = self.args['tree'] + proj = self.args['proj'] + + for words, feats, arcs, rels in loader: + mask = words.ne(self.WORD.pad_index) + # ignore the first token of each sentence + mask[:, 0] = 0 + s_arc, s_rel = self.model(words, feats) + loss = self.model.loss(s_arc, s_rel, arcs, rels, mask) + arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, tree, proj) + # ignore all punctuation if not specified + if not self.args['punct']: + mask &= words.unsqueeze(-1).ne(self.puncts).all(-1) + total_loss += loss.item() + metric(arc_preds, rel_preds, arcs, rels, mask) + total_loss /= len(loader) + + return total_loss, metric + + @torch.no_grad() + def _predict(self, loader): + self.model.eval() + try: + tree = self.args.tree + proj = self.args.proj + prob = self.args.prob + except Exception: + tree = self.args['tree'] + proj = self.args['proj'] + prob = self.args['prob'] + arcs, rels, probs = [], [], [] + for words, feats in progress_bar(loader): + mask = words.ne(self.WORD.pad_index) + # ignore the first token of each sentence + mask[:, 0] = 0 + lens = mask.sum(1).tolist() + s_arc, s_rel = self.model(words, feats) + arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, + tree, proj) + arcs.extend(arc_preds[mask].split(lens)) + rels.extend(rel_preds[mask].split(lens)) + if prob: + arc_probs = s_arc.softmax(-1) + probs.extend([prob[1:i + 1, :i + 1].cpu() for i, prob in zip(lens, arc_probs.unbind())]) + arcs = [seq.tolist() for seq in arcs] + rels = [self.REL.vocab[seq.tolist()] for seq in rels] + preds = {'arcs': arcs, 'rels': rels} + if prob: + preds['probs'] = probs + + return preds diff --git a/underthesea/models/nn.py b/underthesea/models/nn.py new file mode 100644 index 
00000000..6decbe7e --- /dev/null +++ b/underthesea/models/nn.py @@ -0,0 +1,68 @@ +import warnings +from abc import abstractmethod +from pathlib import Path +from typing import Union +import torch.nn +from underthesea import device, file_utils + + +class Model(torch.nn.Module): + r""" + Abstract base class for all downstream task models + Every new type of model must implement these methods. + + Source: FlairNLP + """ + + @abstractmethod + def evaluate(self): + r"""Evaluates the model. Returns a Result object containing evaluation + """ + pass + + def save(self, model_file: Union[str, Path]): + """ + Saves the current model to the provided file. + + Args: + model_file (Union[str, Path]): the model file + + """ + model_state = self._get_state_dict() + + torch.save(model_state, str(model_file), pickle_protocol=4) + + @staticmethod + @abstractmethod + def _fetch_model(model_name) -> str: + return model_name + + @staticmethod + @abstractmethod + def _init_model_with_state_dict(state): + """Initialize the model from a state dictionary. Implementing this enables the load() and load_checkpoint() + functionality.""" + pass + + @classmethod + def load(cls, model: Union[str, Path]): + """ + Loads the model from the given file. + :param model: the model file + :return: the loaded text classifier model + """ + model_file = cls._fetch_model(str(model)) + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + # load_big_file is a workaround by https://github.com/highway11git to load models on some Mac/Windows setups + # see https://github.com/zalandoresearch/flair/issues/351 + f = file_utils.load_big_file(str(model_file)) + state = torch.load(f, map_location='cpu') + + model = cls._init_model_with_state_dict(state) + + model.eval() + model.to(device) + + return model diff --git a/underthesea/modules/__init__.py b/underthesea/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/underthesea/modules/base.py b/underthesea/modules/base.py new file mode 100644 index 00000000..637a031f --- /dev/null +++ b/underthesea/modules/base.py @@ -0,0 +1,471 @@ +import torch +import torch.nn as nn +from torch.nn.utils.rnn import pack_padded_sequence, PackedSequence +from torch.nn.modules.rnn import apply_permutation + + +class CharLSTM(nn.Module): + r""" + CharLSTM aims to generate character-level embeddings for tokens. + It summarizes the information of characters in each token to an embedding using a LSTM layer. + + Args: + n_char (int): + The number of characters. + n_embed (int): + The size of each embedding vector as input to LSTM. + n_out (int): + The size of each output vector. + pad_index (int): + The index of the padding token in the vocabulary. Default: 0. + """ + + def __init__(self, n_chars, n_embed, n_out, pad_index=0): + super().__init__() + + self.n_chars = n_chars + self.n_embed = n_embed + self.n_out = n_out + self.pad_index = pad_index + + # the embedding layer + self.embed = nn.Embedding(num_embeddings=n_chars, + embedding_dim=n_embed) + # the lstm layer + self.lstm = nn.LSTM(input_size=n_embed, + hidden_size=n_out // 2, + batch_first=True, + bidirectional=True) + + def __repr__(self): + s = f"{self.n_chars}, {self.n_embed}, " + s += f"n_out={self.n_out}, " + s += f"pad_index={self.pad_index}" + + return f"{self.__class__.__name__}({s})" + + def forward(self, x): + r""" + Args: + x (~torch.Tensor): ``[batch_size, seq_len, fix_len]``. + Characters of all tokens. + Each token holds no more than `fix_len` characters, and the excess is cut off directly. 
+ Returns: + ~torch.Tensor: + The embeddings of shape ``[batch_size, seq_len, n_out]`` derived from the characters. + """ + # [batch_size, seq_len, fix_len] + mask = x.ne(self.pad_index) + # [batch_size, seq_len] + lens = mask.sum(-1) + char_mask = lens.gt(0) + + # [n, fix_len, n_embed] + x = self.embed(x[char_mask]) + x = pack_padded_sequence(x, lengths=lens[char_mask].cpu(), batch_first=True, enforce_sorted=False) + x, (h, _) = self.lstm(x) + # [n, fix_len, n_out] + h = torch.cat(torch.unbind(h), -1) + # [batch_size, seq_len, n_out] + embed = h.new_zeros(*lens.shape, self.n_out) + embed = embed.masked_scatter_(char_mask.unsqueeze(-1), h) + + return embed + + +class IndependentDropout(nn.Module): + r""" + For :math:`N` tensors, they use different dropout masks respectively. + When :math:`N-M` of them are dropped, the remaining :math:`M` ones are scaled by a factor of :math:`N/M` to compensate, + and when all of them are dropped together, zeros are returned. + + Args: + p (float): + The probability of an element to be zeroed. Default: 0.5. + + Examples: + >>> x, y = torch.ones(1, 3, 5), torch.ones(1, 3, 5) + >>> x, y = IndependentDropout()(x, y) + >>> x + tensor([[[1., 1., 1., 1., 1.], + [0., 0., 0., 0., 0.], + [2., 2., 2., 2., 2.]]]) + >>> y + tensor([[[1., 1., 1., 1., 1.], + [2., 2., 2., 2., 2.], + [0., 0., 0., 0., 0.]]]) + """ + + def __init__(self, p=0.5): + super().__init__() + + self.p = p + + def __repr__(self): + return f"{self.__class__.__name__}(p={self.p})" + + def forward(self, *items): + r""" + Args: + items (list[~torch.Tensor]): + A list of tensors that have the same shape except the last dimension. + Returns: + The returned tensors are of the same shape as `items`. + """ + + if self.training: + masks = [x.new_empty(x.shape[:2]).bernoulli_(1 - self.p) for x in items] + total = sum(masks) + scale = len(items) / total.max(torch.ones_like(total)) + masks = [mask * scale for mask in masks] + items = [item * mask.unsqueeze(-1) for item, mask in zip(items, masks)] + + return items + + +class BiLSTM(nn.Module): + r""" + BiLSTM is an variant of the vanilla bidirectional LSTM adopted by Biaffine Parser + with the only difference of the dropout strategy. + It drops nodes in the LSTM layers (input and recurrent connections) + and applies the same dropout mask at every recurrent timesteps. + + APIs are roughly the same as :class:`~torch.nn.LSTM` except that we remove the ``bidirectional`` option + and only allows :class:`~torch.nn.utils.rnn.PackedSequence` as input. + + References: + - Timothy Dozat and Christopher D. Manning. 2017. + `Deep Biaffine Attention for Neural Dependency Parsing`_. + + Args: + input_size (int): + The number of expected features in the input. + hidden_size (int): + The number of features in the hidden state `h`. + num_layers (int): + The number of recurrent layers. Default: 1. + dropout (float): + If non-zero, introduces a :class:`SharedDropout` layer on the outputs of each LSTM layer except the last layer. + Default: 0. + + .. 
_Deep Biaffine Attention for Neural Dependency Parsing: + https://openreview.net/forum?id=Hk95PK9le + """ + + def __init__(self, input_size, hidden_size, num_layers=1, dropout=0): + super().__init__() + + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.dropout = dropout + + self.f_cells = nn.ModuleList() + self.b_cells = nn.ModuleList() + for _ in range(self.num_layers): + self.f_cells.append(nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)) + self.b_cells.append(nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)) + input_size = hidden_size * 2 + + self.reset_parameters() + + def __repr__(self): + s = f"{self.input_size}, {self.hidden_size}" + if self.num_layers > 1: + s += f", num_layers={self.num_layers}" + if self.dropout > 0: + s += f", dropout={self.dropout}" + + return f"{self.__class__.__name__}({s})" + + def reset_parameters(self): + for param in self.parameters(): + # apply orthogonal_ to weight + if len(param.shape) > 1: + nn.init.orthogonal_(param) + # apply zeros_ to bias + else: + nn.init.zeros_(param) + + def permute_hidden(self, hx, permutation): + if permutation is None: + return hx + h = apply_permutation(hx[0], permutation) + c = apply_permutation(hx[1], permutation) + + return h, c + + def layer_forward(self, x, hx, cell, batch_sizes, reverse=False): + hx_0 = hx_i = hx + hx_n, output = [], [] + steps = reversed(range(len(x))) if reverse else range(len(x)) + if self.training: + hid_mask = SharedDropout.get_mask(hx_0[0], self.dropout) + + for t in steps: + last_batch_size, batch_size = len(hx_i[0]), batch_sizes[t] + if last_batch_size < batch_size: + hx_i = [torch.cat((h, ih[last_batch_size:batch_size])) for h, ih in zip(hx_i, hx_0)] + else: + hx_n.append([h[batch_size:] for h in hx_i]) + hx_i = [h[:batch_size] for h in hx_i] + hx_i = [h for h in cell(x[t], hx_i)] + output.append(hx_i[0]) + if self.training: + hx_i[0] = hx_i[0] * hid_mask[:batch_size] + if reverse: + hx_n = hx_i + output.reverse() + else: + hx_n.append(hx_i) + hx_n = [torch.cat(h) for h in zip(*reversed(hx_n))] + output = torch.cat(output) + + return output, hx_n + + def forward(self, sequence, hx=None): + r""" + Args: + sequence (~torch.nn.utils.rnn.PackedSequence): + A packed variable length sequence. + hx (~torch.Tensor, ~torch.Tensor): + A tuple composed of two tensors `h` and `c`. + `h` of shape ``[num_layers*2, batch_size, hidden_size]`` contains the initial hidden state + for each element in the batch. + `c` of shape ``[num_layers*2, batch_size, hidden_size]`` contains the initial cell state + for each element in the batch. + If `hx` is not provided, both `h` and `c` default to zero. + Default: ``None``. + + Returns: + ~torch.nn.utils.rnn.PackedSequence, (~torch.Tensor, ~torch.Tensor): + The first is a packed variable length sequence. + The second is a tuple of tensors `h` and `c`. + `h` of shape ``[num_layers*2, batch_size, hidden_size]`` contains the hidden state for `t = seq_len`. + Like output, the layers can be separated using ``h.view(num_layers, 2, batch_size, hidden_size)`` + and similarly for c. + `c` of shape ``[num_layers*2, batch_size, hidden_size]`` contains the cell state for `t = seq_len`. 
+ """ + x, batch_sizes = sequence.data, sequence.batch_sizes.tolist() + batch_size = batch_sizes[0] + h_n, c_n = [], [] + + if hx is None: + ih = x.new_zeros(self.num_layers * 2, batch_size, self.hidden_size) + h, c = ih, ih + else: + h, c = self.permute_hidden(hx, sequence.sorted_indices) + h = h.view(self.num_layers, 2, batch_size, self.hidden_size) + c = c.view(self.num_layers, 2, batch_size, self.hidden_size) + + for i in range(self.num_layers): + x = torch.split(x, batch_sizes) + if self.training: + mask = SharedDropout.get_mask(x[0], self.dropout) + x = [i * mask[:len(i)] for i in x] + x_f, (h_f, c_f) = self.layer_forward(x=x, + hx=(h[i, 0], c[i, 0]), + cell=self.f_cells[i], + batch_sizes=batch_sizes) + x_b, (h_b, c_b) = self.layer_forward(x=x, + hx=(h[i, 1], c[i, 1]), + cell=self.b_cells[i], + batch_sizes=batch_sizes, + reverse=True) + x = torch.cat((x_f, x_b), -1) + h_n.append(torch.stack((h_f, h_b))) + c_n.append(torch.stack((c_f, c_b))) + x = PackedSequence(x, + sequence.batch_sizes, + sequence.sorted_indices, + sequence.unsorted_indices) + hx = torch.cat(h_n, 0), torch.cat(c_n, 0) + hx = self.permute_hidden(hx, sequence.unsorted_indices) + + return x, hx + + +class SharedDropout(nn.Module): + r""" + SharedDropout differs from the vanilla dropout strategy in that + the dropout mask is shared across one dimension. + + Args: + p (float): + The probability of an element to be zeroed. Default: 0.5. + batch_first (bool): + If ``True``, the input and output tensors are provided as ``[batch_size, seq_len, *]``. + Default: ``True``. + + Examples: + >>> x = torch.ones(1, 3, 5) + >>> nn.Dropout()(x) + tensor([[[0., 2., 2., 0., 0.], + [2., 2., 0., 2., 2.], + [2., 2., 2., 2., 0.]]]) + >>> SharedDropout()(x) + tensor([[[2., 0., 2., 0., 2.], + [2., 0., 2., 0., 2.], + [2., 0., 2., 0., 2.]]]) + """ + + def __init__(self, p=0.5, batch_first=True): + super().__init__() + + self.p = p + self.batch_first = batch_first + + def __repr__(self): + s = f"p={self.p}" + if self.batch_first: + s += f", batch_first={self.batch_first}" + + return f"{self.__class__.__name__}({s})" + + def forward(self, x): + r""" + Args: + x (~torch.Tensor): + A tensor of any shape. + Returns: + The returned tensor is of the same shape as `x`. + """ + + if self.training: + if self.batch_first: + mask = self.get_mask(x[:, 0], self.p).unsqueeze(1) + else: + mask = self.get_mask(x[0], self.p) + x = x * mask + + return x + + @staticmethod + def get_mask(x, p): + return x.new_empty(x.shape).bernoulli_(1 - p) / (1 - p) + + +class MLP(nn.Module): + r""" + Applies a linear transformation together with :class:`~torch.nn.LeakyReLU` activation to the incoming tensor: + :math:`y = \mathrm{LeakyReLU}(x A^T + b)` + + Args: + n_in (~torch.Tensor): + The size of each input feature. + n_out (~torch.Tensor): + The size of each output feature. + dropout (float): + If non-zero, introduce a :class:`SharedDropout` layer on the output with this dropout ratio. Default: 0. 
+ """ + + def __init__(self, n_in, n_out, dropout=0): + super().__init__() + + self.n_in = n_in + self.n_out = n_out + self.linear = nn.Linear(n_in, n_out) + self.activation = nn.LeakyReLU(negative_slope=0.1) + self.dropout = SharedDropout(p=dropout) + + self.reset_parameters() + + def __repr__(self): + s = f"n_in={self.n_in}, n_out={self.n_out}" + if self.dropout.p > 0: + s += f", dropout={self.dropout.p}" + + return f"{self.__class__.__name__}({s})" + + def reset_parameters(self): + nn.init.orthogonal_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + def forward(self, x): + r""" + Args: + x (~torch.Tensor): + The size of each input feature is `n_in`. + + Returns: + A tensor with the size of each output feature `n_out`. + """ + + x = self.linear(x) + x = self.activation(x) + x = self.dropout(x) + + return x + + +class Biaffine(nn.Module): + r""" + Biaffine layer for first-order scoring. + + This function has a tensor of weights :math:`W` and bias terms if needed. + The score :math:`s(x, y)` of the vector pair :math:`(x, y)` is computed as :math:`x^T W y`, + in which :math:`x` and :math:`y` can be concatenated with bias terms. + + References: + - Timothy Dozat and Christopher D. Manning. 2017. + `Deep Biaffine Attention for Neural Dependency Parsing`_. + + Args: + n_in (int): + The size of the input feature. + n_out (int): + The number of output channels. + bias_x (bool): + If ``True``, adds a bias term for tensor :math:`x`. Default: ``True``. + bias_y (bool): + If ``True``, adds a bias term for tensor :math:`y`. Default: ``True``. + + .. _Deep Biaffine Attention for Neural Dependency Parsing: + https://openreview.net/forum?id=Hk95PK9le + """ + + def __init__(self, n_in, n_out=1, bias_x=True, bias_y=True): + super().__init__() + + self.n_in = n_in + self.n_out = n_out + self.bias_x = bias_x + self.bias_y = bias_y + self.weight = nn.Parameter(torch.Tensor(n_out, n_in + bias_x, n_in + bias_y)) + + self.reset_parameters() + + def __repr__(self): + s = f"n_in={self.n_in}, n_out={self.n_out}" + if self.bias_x: + s += f", bias_x={self.bias_x}" + if self.bias_y: + s += f", bias_y={self.bias_y}" + + return f"{self.__class__.__name__}({s})" + + def reset_parameters(self): + nn.init.zeros_(self.weight) + + def forward(self, x, y): + r""" + Args: + x (torch.Tensor): ``[batch_size, seq_len, n_in]``. + y (torch.Tensor): ``[batch_size, seq_len, n_in]``. + + Returns: + ~torch.Tensor: + A scoring tensor of shape ``[batch_size, n_out, seq_len, seq_len]``. + If ``n_out=1``, the dimension for ``n_out`` will be squeezed automatically. + """ + + if self.bias_x: + x = torch.cat((x, torch.ones_like(x[..., :1])), -1) + if self.bias_y: + y = torch.cat((y, torch.ones_like(y[..., :1])), -1) + # [batch_size, n_out, seq_len, seq_len] + s = torch.einsum('bxi,oij,byj->boxy', x, self.weight, y) + # remove dim 1 if n_out == 1 + s = s.squeeze(1) + + return s diff --git a/underthesea/modules/bert.py b/underthesea/modules/bert.py new file mode 100644 index 00000000..4b5234da --- /dev/null +++ b/underthesea/modules/bert.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- + +import torch.nn as nn +from torch.nn.utils.rnn import pad_sequence +from transformers import AutoConfig, AutoModel + +from underthesea.modules.scalar_mix import ScalarMix + + +class BertEmbedding(nn.Module): + r""" + A module that directly utilizes the pretrained models in `transformers`_ to produce BERT representations. 
+ + While mainly tailored to provide input preparation and post-processing for the BERT model, + it is also compatiable with other pretrained language models like XLNet, RoBERTa and ELECTRA, etc. + + Args: + model (str): + Path or name of the pretrained models registered in `transformers`_, e.g., ``'bert-base-cased'``. + n_layers (int): + The number of layers from the model to use. + If 0, uses all layers. + n_out (int): + The requested size of the embeddings. + If 0, uses the size of the pretrained embedding model. + pad_index (int): + The index of the padding token in the BERT vocabulary. Default: 0. + max_len (int): + Sequences should not exceed the specfied max length. Default: 512. + dropout (float): + The dropout ratio of BERT layers. Default: 0. + This value will be passed into the :class:`ScalarMix` layer. + requires_grad (bool): + If ``True``, the model parameters will be updated together with the downstream task. + Default: ``False``. + + .. _transformers: + https://github.com/huggingface/transformers + """ + + def __init__(self, model, n_layers, n_out, pad_index=0, max_len=512, dropout=0, requires_grad=False): + super().__init__() + + self.model = model + self.bert = AutoModel.from_pretrained(model, + config=AutoConfig.from_pretrained(model, output_hidden_states=True)) + self.bert = self.bert.requires_grad_(requires_grad) + self.n_layers = n_layers or self.bert.config.num_hidden_layers + self.hidden_size = self.bert.config.hidden_size + self.n_out = n_out or self.hidden_size + self.pad_index = pad_index + self.max_len = max_len + self.dropout = dropout + self.requires_grad = requires_grad + + self.scalar_mix = ScalarMix(self.n_layers, dropout) + self.projection = nn.Linear(self.hidden_size, self.n_out, False) if self.hidden_size != n_out else nn.Identity() + + def __repr__(self): + s = f"{self.model}, n_layers={self.n_layers}, n_out={self.n_out}, pad_index={self.pad_index}" + if self.max_len is not None: + s += f", max_len={self.max_len}" + if self.dropout > 0: + s += f", dropout={self.dropout}" + if self.requires_grad: + s += f", requires_grad={self.requires_grad}" + + return f"{self.__class__.__name__}({s})" + + def forward(self, subwords): + r""" + Args: + subwords (~torch.Tensor): ``[batch_size, seq_len, fix_len]``. + + Returns: + ~torch.Tensor: + BERT embeddings of shape ``[batch_size, seq_len, n_out]``. 
+ """ + batch_size, seq_len, fix_len = subwords.shape + if self.max_len and seq_len > self.max_len: + raise RuntimeError(f"Token indices sequence length is longer than the specified max length " + f"({seq_len} > {self.max_len})") + + mask = subwords.ne(self.pad_index) + lens = mask.sum((1, 2)) + # [batch_size, n_subwords] + subwords = pad_sequence(subwords[mask].split(lens.tolist()), True) + bert_mask = pad_sequence(mask[mask].split(lens.tolist()), True) + # return the hidden states of all layers + bert = self.bert(subwords, attention_mask=bert_mask.float())[-1] + # [n_layers, batch_size, n_subwords, hidden_size] + bert = bert[-self.n_layers:] + # [batch_size, n_subwords, hidden_size] + bert = self.scalar_mix(bert) + # [batch_size, n_subwords] + bert_lens = mask.sum(-1) + bert_lens = bert_lens.masked_fill_(bert_lens.eq(0), 1) + # [batch_size, seq_len, fix_len, hidden_size] + embed = bert.new_zeros(*mask.shape, self.hidden_size) + embed = embed.masked_scatter_(mask.unsqueeze(-1), bert[bert_mask]) + # [batch_size, seq_len, hidden_size] + embed = embed.sum(2) / bert_lens.unsqueeze(-1) + embed = self.projection(embed) + + return embed diff --git a/underthesea/modules/model.py b/underthesea/modules/model.py new file mode 100644 index 00000000..68696780 --- /dev/null +++ b/underthesea/modules/model.py @@ -0,0 +1,283 @@ +import torch +import torch.nn as nn +from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence +from underthesea.data import CoNLL +from underthesea.modules.base import CharLSTM, IndependentDropout, BiLSTM, SharedDropout, MLP, Biaffine +from underthesea.modules.bert import BertEmbedding +from underthesea.utils.sp_alg import eisner, mst + + +class BiaffineDependencyModel(nn.Module): + r""" + The implementation of Biaffine Dependency Parser. + + References: + - Timothy Dozat and Christopher D. Manning. 2017. + `Deep Biaffine Attention for Neural Dependency Parsing`_. + + Args: + n_words (int): + The size of the word vocabulary. + n_feats (int): + The size of the feat vocabulary. + n_rels (int): + The number of labels in the treebank. + feat (str): + Specifies which type of additional feature to use: ``'char'`` | ``'bert'`` | ``'tag'``. + ``'char'``: Character-level representations extracted by CharLSTM. + ``'bert'``: BERT representations, other pretrained langugae models like XLNet are also feasible. + ``'tag'``: POS tag embeddings. + Default: ``'char'``. + n_embed (int): + The size of word embeddings. Default: 100. + n_feat_embed (int): + The size of feature representations. Default: 100. + n_char_embed (int): + The size of character embeddings serving as inputs of CharLSTM, required if ``feat='char'``. Default: 50. + bert (str): + Specifies which kind of language model to use, e.g., ``'bert-base-cased'`` and ``'xlnet-base-cased'``. + This is required if ``feat='bert'``. The full list can be found in `transformers`_. + Default: ``None``. + n_bert_layers (int): + Specifies how many last layers to use. Required if ``feat='bert'``. + The final outputs would be the weight sum of the hidden states of these layers. + Default: 4. + max_len (int): + Sequences should not exceed the specfied max length. Default: ``None``. + mix_dropout (float): + The dropout ratio of BERT layers. Required if ``feat='bert'``. Default: .0. + embed_dropout (float): + The dropout ratio of input embeddings. Default: .33. + n_lstm_hidden (int): + The size of LSTM hidden states. Default: 400. + n_lstm_layers (int): + The number of LSTM layers. Default: 3. 
+ lstm_dropout (float): + The dropout ratio of LSTM. Default: .33. + n_mlp_arc (int): + Arc MLP size. Default: 500. + n_mlp_rel (int): + Label MLP size. Default: 100. + mlp_dropout (float): + The dropout ratio of MLP layers. Default: .33. + feat_pad_index (int): + The index of the padding token in the feat vocabulary. Default: 0. + pad_index (int): + The index of the padding token in the word vocabulary. Default: 0. + unk_index (int): + The index of the unknown token in the word vocabulary. Default: 1. + + .. _Deep Biaffine Attention for Neural Dependency Parsing: + https://openreview.net/forum?id=Hk95PK9le + .. _transformers: + https://github.com/huggingface/transformers + """ + + def __init__(self, + n_words, + n_feats, + n_rels, + feat='char', + n_embed=100, + n_feat_embed=100, + n_char_embed=50, + bert=None, + n_bert_layers=4, + max_len=None, + mix_dropout=.0, + embed_dropout=.33, + n_lstm_hidden=400, + n_lstm_layers=3, + lstm_dropout=.33, + n_mlp_arc=500, + n_mlp_rel=100, + mlp_dropout=.33, + feat_pad_index=0, + pad_index=0, + unk_index=1, + **kwargs): + super().__init__() + + self.args = { + "n_words": n_words, + "n_feats": n_feats, + "n_rels": n_rels, + "feat": feat, + 'tree': False, + 'proj': False, + 'punct': False + } + + # the embedding layer + self.word_embed = nn.Embedding(num_embeddings=n_words, + embedding_dim=n_embed) + if feat == 'char': + self.feat_embed = CharLSTM(n_chars=n_feats, + n_embed=n_char_embed, + n_out=n_feat_embed, + pad_index=feat_pad_index) + elif feat == 'bert': + self.feat_embed = BertEmbedding(model=bert, + n_layers=n_bert_layers, + n_out=n_feat_embed, + pad_index=feat_pad_index, + max_len=max_len, + dropout=mix_dropout) + self.n_feat_embed = self.feat_embed.n_out + elif feat == 'tag': + self.feat_embed = nn.Embedding(num_embeddings=n_feats, + embedding_dim=n_feat_embed) + else: + raise RuntimeError("The feat type should be in ['char', 'bert', 'tag'].") + self.embed_dropout = IndependentDropout(p=embed_dropout) + + # the lstm layer + self.lstm = BiLSTM(input_size=n_embed + n_feat_embed, + hidden_size=n_lstm_hidden, + num_layers=n_lstm_layers, + dropout=lstm_dropout) + self.lstm_dropout = SharedDropout(p=lstm_dropout) + + # the MLP layers + self.mlp_arc_d = MLP(n_in=n_lstm_hidden * 2, + n_out=n_mlp_arc, + dropout=mlp_dropout) + self.mlp_arc_h = MLP(n_in=n_lstm_hidden * 2, + n_out=n_mlp_arc, + dropout=mlp_dropout) + self.mlp_rel_d = MLP(n_in=n_lstm_hidden * 2, + n_out=n_mlp_rel, + dropout=mlp_dropout) + self.mlp_rel_h = MLP(n_in=n_lstm_hidden * 2, + n_out=n_mlp_rel, + dropout=mlp_dropout) + + # the Biaffine layers + self.arc_attn = Biaffine(n_in=n_mlp_arc, + bias_x=True, + bias_y=False) + self.rel_attn = Biaffine(n_in=n_mlp_rel, + n_out=n_rels, + bias_x=True, + bias_y=True) + self.criterion = nn.CrossEntropyLoss() + self.pad_index = pad_index + self.unk_index = unk_index + + def load_pretrained(self, embed=None): + if embed is not None: + self.pretrained = nn.Embedding.from_pretrained(embed) + nn.init.zeros_(self.word_embed.weight) + return self + + def forward(self, words, feats): + r""" + Args: + words (torch.LongTensor): ``[batch_size, seq_len]``. + Word indices. + feats (torch.LongTensor): + Feat indices. + If feat is ``'char'`` or ``'bert'``, the size of feats should be ``[batch_size, seq_len, fix_len]``. + if ``'tag'``, the size is ``[batch_size, seq_len]``. + + Returns: + torch.Tensor, torch.Tensor: + The first tensor of shape ``[batch_size, seq_len, seq_len]`` holds scores of all possible arcs. 
+ The second of shape ``[batch_size, seq_len, seq_len, n_labels]`` holds + scores of all possible labels on each arc. + """ + + batch_size, seq_len = words.shape + # get the mask and lengths of given batch + mask = words.ne(self.pad_index) + ext_words = words + # set the indices larger than num_embeddings to unk_index + if hasattr(self, 'pretrained'): + ext_mask = words.ge(self.word_embed.num_embeddings) + ext_words = words.masked_fill(ext_mask, self.unk_index) + + # get outputs from embedding layers + word_embed = self.word_embed(ext_words) + if hasattr(self, 'pretrained'): + word_embed += self.pretrained(words) + feat_embed = self.feat_embed(feats) + word_embed, feat_embed = self.embed_dropout(word_embed, feat_embed) + # concatenate the word and feat representations + embed = torch.cat((word_embed, feat_embed), -1) + + x = pack_padded_sequence(embed, mask.sum(1), True, False) + x, _ = self.lstm(x) + x, _ = pad_packed_sequence(x, True, total_length=seq_len) + x = self.lstm_dropout(x) + + # apply MLPs to the BiLSTM output states + arc_d = self.mlp_arc_d(x) + arc_h = self.mlp_arc_h(x) + rel_d = self.mlp_rel_d(x) + rel_h = self.mlp_rel_h(x) + + # [batch_size, seq_len, seq_len] + s_arc = self.arc_attn(arc_d, arc_h) + # [batch_size, seq_len, seq_len, n_rels] + s_rel = self.rel_attn(rel_d, rel_h).permute(0, 2, 3, 1) + # set the scores that exceed the length of each sentence to -inf + s_arc.masked_fill_(~mask.unsqueeze(1), float('-inf')) + + return s_arc, s_rel + + def loss(self, s_arc, s_rel, arcs, rels, mask): + r""" + Args: + s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. + Scores of all possible arcs. + s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``. + Scores of all possible labels on each arc. + arcs (~torch.LongTensor): ``[batch_size, seq_len]``. + The tensor of gold-standard arcs. + rels (~torch.LongTensor): ``[batch_size, seq_len]``. + The tensor of gold-standard labels. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask for covering the unpadded tokens. + + Returns: + ~torch.Tensor: + The training loss. + """ + + s_arc, arcs = s_arc[mask], arcs[mask] + s_rel, rels = s_rel[mask], rels[mask] + s_rel = s_rel[torch.arange(len(arcs)), arcs] + arc_loss = self.criterion(s_arc, arcs) + rel_loss = self.criterion(s_rel, rels) + + return arc_loss + rel_loss + + def decode(self, s_arc, s_rel, mask, tree=False, proj=False): + r""" + Args: + s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. + Scores of all possible arcs. + s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``. + Scores of all possible labels on each arc. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask for covering the unpadded tokens. + tree (bool): + If ``True``, ensures to output well-formed trees. Default: ``False``. + proj (bool): + If ``True``, ensures to output projective trees. Default: ``False``. + + Returns: + ~torch.Tensor, ~torch.Tensor: + Predicted arcs and labels of shape ``[batch_size, seq_len]``. 
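+
+        Examples:
+            A shape-level sketch of the greedy path (``tree=False``); the score tensors
+            below are random stand-ins for illustration, not outputs of a trained model:
+
+            >>> import torch
+            >>> s_arc, s_rel = torch.randn(2, 5, 5), torch.randn(2, 5, 5, 10)
+            >>> arc_preds = s_arc.argmax(-1)
+            >>> rel_preds = s_rel.argmax(-1).gather(-1, arc_preds.unsqueeze(-1)).squeeze(-1)
+            >>> arc_preds.shape, rel_preds.shape
+            (torch.Size([2, 5]), torch.Size([2, 5]))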
+ """ + + lens = mask.sum(1) + arc_preds = s_arc.argmax(-1) + bad = [not CoNLL.istree(seq[1:i + 1], proj) + for i, seq in zip(lens.tolist(), arc_preds.tolist())] + if tree and any(bad): + alg = eisner if proj else mst + arc_preds[bad] = alg(s_arc[bad], mask[bad]) + rel_preds = s_rel.argmax(-1).gather(-1, arc_preds.unsqueeze(-1)).squeeze(-1) + + return arc_preds, rel_preds diff --git a/underthesea/modules/scalar_mix.py b/underthesea/modules/scalar_mix.py new file mode 100644 index 00000000..8d528d6a --- /dev/null +++ b/underthesea/modules/scalar_mix.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- + +import torch +import torch.nn as nn + + +class ScalarMix(nn.Module): + r""" + Computes a parameterised scalar mixture of :math:`N` tensors, :math:`mixture = \gamma * \sum_{k}(s_k * tensor_k)` + where :math:`s = \mathrm{softmax}(w)`, with :math:`w` and :math:`\gamma` scalar parameters. + + Args: + n_layers (int): + The number of layers to be mixed, i.e., :math:`N`. + dropout (float): + The dropout ratio of the layer weights. + If dropout > 0, then for each scalar weight, adjust its softmax weight mass to 0 + with the dropout probability (i.e., setting the unnormalized weight to -inf). + This effectively redistributes the dropped probability mass to all other weights. + Default: 0. + """ + + def __init__(self, n_layers, dropout=0): + super().__init__() + + self.n_layers = n_layers + + self.weights = nn.Parameter(torch.zeros(n_layers)) + self.gamma = nn.Parameter(torch.tensor([1.0])) + self.dropout = nn.Dropout(dropout) + + def __repr__(self): + s = f"n_layers={self.n_layers}" + if self.dropout.p > 0: + s += f", dropout={self.dropout.p}" + + return f"{self.__class__.__name__}({s})" + + def forward(self, tensors): + r""" + Args: + tensors (list[~torch.Tensor]): + :math:`N` tensors to be mixed. + + Returns: + The mixture of :math:`N` tensors. 
+ """ + + normed_weights = self.dropout(self.weights.softmax(-1)) + weighted_sum = sum(w * h for w, h in zip(normed_weights, tensors)) + + return self.gamma * weighted_sum diff --git a/underthesea/trainers/parser_trainer.py b/underthesea/trainers/parser_trainer.py new file mode 100644 index 00000000..cb038daa --- /dev/null +++ b/underthesea/trainers/parser_trainer.py @@ -0,0 +1,196 @@ +import os +from datetime import timedelta, datetime +from pathlib import Path +from typing import Union +import torch.distributed as dist +from torch.optim import Adam +from torch.optim.lr_scheduler import ExponentialLR +from underthesea import logger, device +from underthesea.data import CoNLL +from underthesea.models.dependency_parser import BiaffineDependencyParserSupar +from underthesea.modules.model import BiaffineDependencyModel +from underthesea.utils.sp_common import pad, unk, bos +from underthesea.utils.sp_config import Config +from underthesea.utils.sp_data import Dataset +from underthesea.utils.sp_embedding import Embedding +from underthesea.utils.sp_field import Field, SubwordField +from underthesea.utils.sp_metric import Metric +from underthesea.utils.sp_parallel import DistributedDataParallel as DDP, is_master + + +class ParserTrainer: + def __init__(self, parser, corpus): + self.parser = parser + self.corpus = corpus + + def train( + self, base_path: Union[Path, str], + fix_len=20, + min_freq=2, + buckets=32, + batch_size=5000, + punct=False, + tree=False, + proj=False, + lr=2e-3, + mu=.9, + nu=.9, + epsilon=1e-12, + clip=5.0, + decay=.75, + decay_steps=5000, + patience=100, + verbose=True, + max_epochs=10, + **kwargs + ): + r""" + Train any class that implement model interface + + Args: + base_path (object): Main path to which all output during training is logged and models are saved + max_epochs: Maximum number of epochs to train. Terminates training if this number is surpassed. 
+ verbose: + patience: + decay_steps: + decay: + clip: + epsilon: + nu: + mu: + lr: + proj: + tree: + punct: + batch_size: + buckets: + min_freq: + fix_len: + + + """ + ################################################################################################################ + # BUILD + ################################################################################################################ + locals_args = { + 'base_path': base_path, + 'fix_len': fix_len, + 'min_freq': min_freq, + 'max_epochs': max_epochs + } + args = Config(**locals_args) + args.feat = self.parser.embeddings + args.embed = self.parser.embed + os.makedirs(os.path.dirname(base_path), exist_ok=True) + + logger.info("Building the fields") + WORD = Field('words', pad=pad, unk=unk, bos=bos, lower=True) + if args.feat == 'char': + FEAT = SubwordField('chars', pad=pad, unk=unk, bos=bos, fix_len=args.fix_len) + elif args.feat == 'bert': + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(args.bert) + args.max_len = min(args.max_len or tokenizer.max_len, tokenizer.max_len) + FEAT = SubwordField('bert', + pad=tokenizer.pad_token, + unk=tokenizer.unk_token, + bos=tokenizer.bos_token or tokenizer.cls_token, + fix_len=args.fix_len, + tokenize=tokenizer.tokenize) + FEAT.vocab = tokenizer.get_vocab() + else: + FEAT = Field('tags', bos=bos) + + ARC = Field('arcs', bos=bos, use_vocab=False, fn=CoNLL.get_arcs) + REL = Field('rels', bos=bos) + if args.feat in ('char', 'bert'): + transform = CoNLL(FORM=(WORD, FEAT), HEAD=ARC, DEPREL=REL) + else: + transform = CoNLL(FORM=WORD, CPOS=FEAT, HEAD=ARC, DEPREL=REL) + + train = Dataset(transform, self.corpus.train) + WORD.build(train, min_freq, (Embedding.load(args.embed, unk) if self.parser.embed else None)) + FEAT.build(train) + REL.build(train) + args.update({ + 'n_words': WORD.vocab.n_init, + 'n_feats': len(FEAT.vocab), + 'n_rels': len(REL.vocab), + 'pad_index': WORD.pad_index, + 'unk_index': WORD.unk_index, + 'bos_index': WORD.bos_index, + 'feat_pad_index': FEAT.pad_index, + 'device': device, + 'path': base_path + }) + model = BiaffineDependencyModel(**args) + model.load_pretrained(WORD.embed).to(device) + parser_supar = BiaffineDependencyParserSupar(args, model, transform) + + ################################################################################################################ + # TRAIN + ################################################################################################################ + args = Config() + args.update({ + 'train': self.corpus.train, + 'dev': self.corpus.dev, + 'test': self.corpus.test + }) + parser_supar.transform.train() + parser_supar.args.clip = clip + parser_supar.args.punct = punct + parser_supar.args.tree = tree + parser_supar.args.proj = proj + if dist.is_initialized(): + batch_size = batch_size // dist.get_world_size() + logger.info("Loading the data") + train = Dataset(parser_supar.transform, self.corpus.train, **args) + dev = Dataset(parser_supar.transform, self.corpus.dev) + test = Dataset(parser_supar.transform, self.corpus.test) + train.build(batch_size, buckets, True, dist.is_initialized()) + dev.build(batch_size, buckets) + test.build(batch_size, buckets) + logger.info(f"\n{'train:':6} {train}\n{'dev:':6} {dev}\n{'test:':6} {test}\n") + + logger.info(f"{parser_supar.model}\n") + if dist.is_initialized(): + parser_supar.model = DDP(parser_supar.model, + device_ids=[dist.get_rank()], + find_unused_parameters=True) + parser_supar.optimizer = Adam(parser_supar.model.parameters(), + lr, + (mu, nu), 
+ epsilon) + parser_supar.scheduler = ExponentialLR(parser_supar.optimizer, decay ** (1 / decay_steps)) + + elapsed = timedelta() + best_e, best_metric = 1, Metric() + + for epoch in range(1, max_epochs + 1): + start = datetime.now() + + logger.info(f"Epoch {epoch} / {max_epochs}:") + parser_supar._train(train.loader) + loss, dev_metric = parser_supar._evaluate(dev.loader) + logger.info(f"{'dev:':6} - loss: {loss:.4f} - {dev_metric}") + loss, test_metric = parser_supar._evaluate(test.loader) + logger.info(f"{'test:':6} - loss: {loss:.4f} - {test_metric}") + + t = datetime.now() - start + # save the model if it is the best so far + if dev_metric > best_metric: + best_e, best_metric = epoch, dev_metric + if is_master(): + parser_supar.save(base_path) + logger.info(f"{t}s elapsed (saved)\n") + else: + logger.info(f"{t}s elapsed\n") + elapsed += t + if epoch - best_e >= patience: + break + loss, metric = parser_supar.load(base_path)._evaluate(test.loader) + + logger.info(f"Epoch {best_e} saved") + logger.info(f"{'dev:':6} - {best_metric}") + logger.info(f"{'test:':6} - {metric}") + logger.info(f"{elapsed}s elapsed, {elapsed / epoch}s/epoch") diff --git a/underthesea/utils/__init__.py b/underthesea/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/underthesea/utils/sp_alg.py b/underthesea/utils/sp_alg.py new file mode 100644 index 00000000..549d56bd --- /dev/null +++ b/underthesea/utils/sp_alg.py @@ -0,0 +1,612 @@ +# -*- coding: utf-8 -*- + +import torch +from underthesea.utils.sp_fn import pad, stripe + + +def kmeans(x, k, max_it=32): + r""" + KMeans algorithm for clustering the sentences by length. + + Args: + x (list[int]): + The list of sentence lengths. + k (int): + The number of clusters. + This is an approximate value. The final number of clusters can be less or equal to `k`. + max_it (int): + Maximum number of iterations. + If centroids does not converge after several iterations, the algorithm will be early stopped. + + Returns: + list[float], list[list[int]]: + The first list contains average lengths of sentences in each cluster. + The second is the list of clusters holding indices of data points. 
+ + Examples: + >>> x = torch.randint(10,20,(10,)).tolist() + >>> x + [15, 10, 17, 11, 18, 13, 17, 19, 18, 14] + >>> centroids, clusters = kmeans(x, 3) + >>> centroids + [10.5, 14.0, 17.799999237060547] + >>> clusters + [[1, 3], [0, 5, 9], [2, 4, 6, 7, 8]] + """ + + # the number of clusters must not be greater than the number of datapoints + x, k = torch.tensor(x, dtype=torch.float), min(len(x), k) + # collect unique datapoints + d = x.unique() + # initialize k centroids randomly + c = d[torch.randperm(len(d))[:k]] + # assign each datapoint to the cluster with the closest centroid + dists, y = torch.abs_(x.unsqueeze(-1) - c).min(-1) + + for _ in range(max_it): + # if an empty cluster is encountered, + # choose the farthest datapoint from the biggest cluster and move that the empty one + mask = torch.arange(k).unsqueeze(-1).eq(y) + none = torch.where(~mask.any(-1))[0].tolist() + while len(none) > 0: + for i in none: + # the biggest cluster + b = torch.where(mask[mask.sum(-1).argmax()])[0] + # the datapoint farthest from the centroid of cluster b + f = dists[b].argmax() + # update the assigned cluster of f + y[b[f]] = i + # re-calculate the mask + mask = torch.arange(k).unsqueeze(-1).eq(y) + none = torch.where(~mask.any(-1))[0].tolist() + # update the centroids + c, old = (x * mask).sum(-1) / mask.sum(-1), c + # re-assign all datapoints to clusters + dists, y = torch.abs_(x.unsqueeze(-1) - c).min(-1) + # stop iteration early if the centroids converge + if c.equal(old): + break + # assign all datapoints to the new-generated clusters + # the empty ones are discarded + assigned = y.unique().tolist() + # get the centroids of the assigned clusters + centroids = c[assigned].tolist() + # map all values of datapoints to buckets + clusters = [torch.where(y.eq(i))[0].tolist() for i in assigned] + + return centroids, clusters + + +# flake8: noqa: C901 +def tarjan(sequence): + r""" + Tarjan algorithm for finding Strongly Connected Components (SCCs) of a graph. + + Args: + sequence (list): + List of head indices. + + Yields: + A list of indices that make up a SCC. All self-loops are ignored. + + Examples: + >>> next(tarjan([2, 5, 0, 3, 1])) # (1 -> 5 -> 2 -> 1) is a cycle + [2, 5, 1] + """ + + sequence = [-1] + sequence + # record the search order, i.e., the timestep + dfn = [-1] * len(sequence) + # record the the smallest timestep in a SCC + low = [-1] * len(sequence) + # push the visited into the stack + stack, onstack = [], [False] * len(sequence) + + def connect(i, timestep): + dfn[i] = low[i] = timestep[0] + timestep[0] += 1 + stack.append(i) + onstack[i] = True + + for j, head in enumerate(sequence): + if head != i: + continue + if dfn[j] == -1: + yield from connect(j, timestep) + low[i] = min(low[i], low[j]) + elif onstack[j]: + low[i] = min(low[i], dfn[j]) + + # a SCC is completed + if low[i] == dfn[i]: + cycle = [stack.pop()] + while cycle[-1] != i: + onstack[cycle[-1]] = False + cycle.append(stack.pop()) + onstack[i] = False + # ignore the self-loop + if len(cycle) > 1: + yield cycle + + timestep = [0] + for i in range(len(sequence)): + if dfn[i] == -1: + yield from connect(i, timestep) + + +def chuliu_edmonds(s): + r""" + ChuLiu/Edmonds algorithm for non-projective decoding. + + Some code is borrowed from `tdozat's implementation`_. + Descriptions of notations and formulas can be found in + `Non-projective Dependency Parsing using Spanning Tree Algorithms`_. + + Notes: + The algorithm does not guarantee to parse a single-root tree. 
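+
+        A quick shape check (illustrative; random scores for a 3-word sentence plus the
+        pseudo root at position 0, so the returned heads will vary from run to run):
+
+        >>> import torch
+        >>> from underthesea.utils.sp_alg import chuliu_edmonds
+        >>> chuliu_edmonds(torch.randn(4, 4)).shape
+        torch.Size([4])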
+ + References: + - Ryan McDonald, Fernando Pereira, Kiril Ribarov and Jan Hajic. 2005. + `Non-projective Dependency Parsing using Spanning Tree Algorithms`_. + + Args: + s (~torch.Tensor): ``[seq_len, seq_len]``. + Scores of all dependent-head pairs. + + Returns: + ~torch.Tensor: + A tensor with shape ``[seq_len]`` for the resulting non-projective parse tree. + + .. _tdozat's implementation: + https://github.com/tdozat/Parser-v3 + .. _Non-projective Dependency Parsing using Spanning Tree Algorithms: + https://www.aclweb.org/anthology/H05-1066/ + """ + + s[0, 1:] = float('-inf') + # prevent self-loops + s.diagonal()[1:].fill_(float('-inf')) + # select heads with highest scores + tree = s.argmax(-1) + # return the cycle finded by tarjan algorithm lazily + cycle = next(tarjan(tree.tolist()[1:]), None) + # if the tree has no cycles, then it is a MST + if not cycle: + return tree + # indices of cycle in the original tree + cycle = torch.tensor(cycle) + # indices of noncycle in the original tree + noncycle = torch.ones(len(s)).index_fill_(0, cycle, 0) + noncycle = torch.where(noncycle.gt(0))[0] + + def contract(s): + # heads of cycle in original tree + cycle_heads = tree[cycle] + # scores of cycle in original tree + s_cycle = s[cycle, cycle_heads] + + # calculate the scores of cycle's potential dependents + # s(c->x) = max(s(x'->x)), x in noncycle and x' in cycle + s_dep = s[noncycle][:, cycle] + # find the best cycle head for each noncycle dependent + deps = s_dep.argmax(1) + # calculate the scores of cycle's potential heads + # s(x->c) = max(s(x'->x) - s(a(x')->x') + s(cycle)), x in noncycle and x' in cycle + # a(v) is the predecessor of v in cycle + # s(cycle) = sum(s(a(v)->v)) + s_head = s[cycle][:, noncycle] - s_cycle.view(-1, 1) + s_cycle.sum() + # find the best noncycle head for each cycle dependent + heads = s_head.argmax(0) + + contracted = torch.cat((noncycle, torch.tensor([-1]))) + # calculate the scores of contracted graph + s = s[contracted][:, contracted] + # set the contracted graph scores of cycle's potential dependents + s[:-1, -1] = s_dep[range(len(deps)), deps] + # set the contracted graph scores of cycle's potential heads + s[-1, :-1] = s_head[heads, range(len(heads))] + + return s, heads, deps + + # keep track of the endpoints of the edges into and out of cycle for reconstruction later + s, heads, deps = contract(s) + + # y is the contracted tree + y = chuliu_edmonds(s) + # exclude head of cycle from y + y, cycle_head = y[:-1], y[-1] + + # fix the subtree with no heads coming from the cycle + # len(y) denotes heads coming from the cycle + subtree = y < len(y) + # add the nodes to the new tree + tree[noncycle[subtree]] = noncycle[y[subtree]] + # fix the subtree with heads coming from the cycle + subtree = ~subtree + # add the nodes to the tree + tree[noncycle[subtree]] = cycle[deps[subtree]] + # fix the root of the cycle + cycle_root = heads[cycle_head] + # break the cycle and add the root of the cycle to the tree + tree[cycle[cycle_root]] = noncycle[cycle_head] + + return tree + + +def mst(scores, mask, multiroot=False): + r""" + MST algorithm for decoding non-pojective trees. + This is a wrapper for ChuLiu/Edmonds algorithm. + + The algorithm first runs ChuLiu/Edmonds to parse a tree and then have a check of multi-roots, + If ``multiroot=True`` and there indeed exist multi-roots, the algorithm seeks to find + best single-root trees by iterating all possible single-root trees parsed by ChuLiu/Edmonds. + Otherwise the resulting trees are directly taken as the final outputs. 
+ + Args: + scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. + Scores of all dependent-head pairs. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask to avoid parsing over padding tokens. + The first column serving as pseudo words for roots should be ``False``. + muliroot (bool): + Ensures to parse a single-root tree If ``False``. + + Returns: + ~torch.Tensor: + A tensor with shape ``[batch_size, seq_len]`` for the resulting non-projective parse trees. + + Examples: + >>> scores = torch.tensor([[[-11.9436, -13.1464, -6.4789, -13.8917], + [-60.6957, -60.2866, -48.6457, -63.8125], + [-38.1747, -49.9296, -45.2733, -49.5571], + [-19.7504, -23.9066, -9.9139, -16.2088]]]) + >>> scores[:, 0, 1:] = float('-inf') + >>> scores.diagonal(0, 1, 2)[1:].fill_(float('-inf')) + >>> mask = torch.tensor([[False, True, True, True]]) + >>> mst(scores, mask) + tensor([[0, 2, 0, 2]]) + """ + + batch_size, seq_len, _ = scores.shape + scores = scores.cpu().unbind() + + preds = [] + for i, length in enumerate(mask.sum(1).tolist()): + s = scores[i][:length + 1, :length + 1] + tree = chuliu_edmonds(s) + roots = torch.where(tree[1:].eq(0))[0] + 1 + if not multiroot and len(roots) > 1: + s_root = s[:, 0] + s_best = float('-inf') + s = s.index_fill(1, torch.tensor(0), float('-inf')) + for root in roots: + s[:, 0] = float('-inf') + s[root, 0] = s_root[root] + t = chuliu_edmonds(s) + s_tree = s[1:].gather(1, t[1:].unsqueeze(-1)).sum() + if s_tree > s_best: + s_best, tree = s_tree, t + preds.append(tree) + + return pad(preds, total_length=seq_len).to(mask.device) + + +def eisner(scores, mask): + r""" + First-order Eisner algorithm for projective decoding. + + References: + - Ryan McDonald, Koby Crammer and Fernando Pereira. 2005. + `Online Large-Margin Training of Dependency Parsers`_. + + Args: + scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. + Scores of all dependent-head pairs. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask to avoid parsing over padding tokens. + The first column serving as pseudo words for roots should be ``False``. + + Returns: + ~torch.Tensor: + A tensor with shape ``[batch_size, seq_len]`` for the resulting projective parse trees. + + Examples: + >>> scores = torch.tensor([[[-13.5026, -18.3700, -13.0033, -16.6809], + [-36.5235, -28.6344, -28.4696, -31.6750], + [ -2.9084, -7.4825, -1.4861, -6.8709], + [-29.4880, -27.6905, -26.1498, -27.0233]]]) + >>> mask = torch.tensor([[False, True, True, True]]) + >>> eisner(scores, mask) + tensor([[0, 2, 0, 2]]) + + .. 
_Online Large-Margin Training of Dependency Parsers: + https://www.aclweb.org/anthology/P05-1012/ + """ + + lens = mask.sum(1) + batch_size, seq_len, _ = scores.shape + scores = scores.permute(2, 1, 0) + s_i = torch.full_like(scores, float('-inf')) + s_c = torch.full_like(scores, float('-inf')) + p_i = scores.new_zeros(seq_len, seq_len, batch_size).long() + p_c = scores.new_zeros(seq_len, seq_len, batch_size).long() + s_c.diagonal().fill_(0) + + for w in range(1, seq_len): + n = seq_len - w + starts = p_i.new_tensor(range(n)).unsqueeze(0) + # ilr = C(i->r) + C(j->r+1) + ilr = stripe(s_c, n, w) + stripe(s_c, n, w, (w, 1)) + # [batch_size, n, w] + il = ir = ilr.permute(2, 0, 1) + # I(j->i) = max(C(i->r) + C(j->r+1) + s(j->i)), i <= r < j + il_span, il_path = il.max(-1) + s_i.diagonal(-w).copy_(il_span + scores.diagonal(-w)) + p_i.diagonal(-w).copy_(il_path + starts) + # I(i->j) = max(C(i->r) + C(j->r+1) + s(i->j)), i <= r < j + ir_span, ir_path = ir.max(-1) + s_i.diagonal(w).copy_(ir_span + scores.diagonal(w)) + p_i.diagonal(w).copy_(ir_path + starts) + + # C(j->i) = max(C(r->i) + I(j->r)), i <= r < j + cl = stripe(s_c, n, w, (0, 0), 0) + stripe(s_i, n, w, (w, 0)) + cl_span, cl_path = cl.permute(2, 0, 1).max(-1) + s_c.diagonal(-w).copy_(cl_span) + p_c.diagonal(-w).copy_(cl_path + starts) + # C(i->j) = max(I(i->r) + C(r->j)), i < r <= j + cr = stripe(s_i, n, w, (0, 1)) + stripe(s_c, n, w, (1, w), 0) + cr_span, cr_path = cr.permute(2, 0, 1).max(-1) + s_c.diagonal(w).copy_(cr_span) + s_c[0, w][lens.ne(w)] = float('-inf') + p_c.diagonal(w).copy_(cr_path + starts + 1) + + def backtrack(p_i, p_c, heads, i, j, complete): + if i == j: + return + if complete: + r = p_c[i, j] + backtrack(p_i, p_c, heads, i, r, False) + backtrack(p_i, p_c, heads, r, j, True) + else: + r, heads[j] = p_i[i, j], i + i, j = sorted((i, j)) + backtrack(p_i, p_c, heads, i, r, True) + backtrack(p_i, p_c, heads, j, r + 1, True) + + preds = [] + p_c = p_c.permute(2, 0, 1).cpu() + p_i = p_i.permute(2, 0, 1).cpu() + for i, length in enumerate(lens.tolist()): + heads = p_c.new_zeros(length + 1, dtype=torch.long) + backtrack(p_i[i], p_c[i], heads, 0, length, True) + preds.append(heads.to(mask.device)) + + return pad(preds, total_length=seq_len).to(mask.device) + + +def eisner2o(scores, mask): + r""" + Second-order Eisner algorithm for projective decoding. + This is an extension of the first-order one that further incorporates sibling scores into tree scoring. + + References: + - Ryan McDonald and Fernando Pereira. 2006. + `Online Learning of Approximate Dependency Parsing Algorithms`_. + + Args: + scores (~torch.Tensor, ~torch.Tensor): + A tuple of two tensors representing the first-order and second-order scores repectively. + The first (``[batch_size, seq_len, seq_len]``) holds scores of all dependent-head pairs. + The second (``[batch_size, seq_len, seq_len, seq_len]``) holds scores of all dependent-head-sibling triples. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask to avoid parsing over padding tokens. + The first column serving as pseudo words for roots should be ``False``. + + Returns: + ~torch.Tensor: + A tensor with shape ``[batch_size, seq_len]`` for the resulting projective parse trees. 
+ + Examples: + >>> s_arc = torch.tensor([[[ -2.8092, -7.9104, -0.9414, -5.4360], + [-10.3494, -7.9298, -3.6929, -7.3985], + [ 1.1815, -3.8291, 2.3166, -2.7183], + [ -3.9776, -3.9063, -1.6762, -3.1861]]]) + >>> s_sib = torch.tensor([[[[ 0.4719, 0.4154, 1.1333, 0.6946], + [ 1.1252, 1.3043, 2.1128, 1.4621], + [ 0.5974, 0.5635, 1.0115, 0.7550], + [ 1.1174, 1.3794, 2.2567, 1.4043]], + [[-2.1480, -4.1830, -2.5519, -1.8020], + [-1.2496, -1.7859, -0.0665, -0.4938], + [-2.6171, -4.0142, -2.9428, -2.2121], + [-0.5166, -1.0925, 0.5190, 0.1371]], + [[ 0.5827, -1.2499, -0.0648, -0.0497], + [ 1.4695, 0.3522, 1.5614, 1.0236], + [ 0.4647, -0.7996, -0.3801, 0.0046], + [ 1.5611, 0.3875, 1.8285, 1.0766]], + [[-1.3053, -2.9423, -1.5779, -1.2142], + [-0.1908, -0.9699, 0.3085, 0.1061], + [-1.6783, -2.8199, -1.8853, -1.5653], + [ 0.3629, -0.3488, 0.9011, 0.5674]]]]) + >>> mask = torch.tensor([[False, True, True, True]]) + >>> eisner2o((s_arc, s_sib), mask) + tensor([[0, 2, 0, 2]]) + + .. _Online Learning of Approximate Dependency Parsing Algorithms: + https://www.aclweb.org/anthology/E06-1011/ + """ + + # the end position of each sentence in a batch + lens = mask.sum(1) + s_arc, s_sib = scores + batch_size, seq_len, _ = s_arc.shape + # [seq_len, seq_len, batch_size] + s_arc = s_arc.permute(2, 1, 0) + # [seq_len, seq_len, seq_len, batch_size] + s_sib = s_sib.permute(2, 1, 3, 0) + s_i = torch.full_like(s_arc, float('-inf')) + s_s = torch.full_like(s_arc, float('-inf')) + s_c = torch.full_like(s_arc, float('-inf')) + p_i = s_arc.new_zeros(seq_len, seq_len, batch_size).long() + p_s = s_arc.new_zeros(seq_len, seq_len, batch_size).long() + p_c = s_arc.new_zeros(seq_len, seq_len, batch_size).long() + s_c.diagonal().fill_(0) + + for w in range(1, seq_len): + # n denotes the number of spans to iterate, + # from span (0, w) to span (n, n+w) given width w + n = seq_len - w + starts = p_i.new_tensor(range(n)).unsqueeze(0) + # I(j->i) = max(I(j->r) + S(j->r, i)), i < r < j | + # C(j->j) + C(i->j-1)) + # + s(j->i) + # [n, w, batch_size] + il = stripe(s_i, n, w, (w, 1)) + stripe(s_s, n, w, (1, 0), 0) + il += stripe(s_sib[range(w, n + w), range(n)], n, w, (0, 1)) + # [n, 1, batch_size] + il0 = stripe(s_c, n, 1, (w, w)) + stripe(s_c, n, 1, (0, w - 1)) + # il0[0] are set to zeros since the scores of the complete spans starting from 0 are always -inf + il[:, -1] = il0.index_fill_(0, lens.new_tensor(0), 0).squeeze(1) + il_span, il_path = il.permute(2, 0, 1).max(-1) + s_i.diagonal(-w).copy_(il_span + s_arc.diagonal(-w)) + p_i.diagonal(-w).copy_(il_path + starts + 1) + # I(i->j) = max(I(i->r) + S(i->r, j), i < r < j | + # C(i->i) + C(j->i+1)) + # + s(i->j) + # [n, w, batch_size] + ir = stripe(s_i, n, w) + stripe(s_s, n, w, (0, w), 0) + ir += stripe(s_sib[range(n), range(w, n + w)], n, w) + ir[0] = float('-inf') + # [n, 1, batch_size] + ir0 = stripe(s_c, n, 1) + stripe(s_c, n, 1, (w, 1)) + ir[:, 0] = ir0.squeeze(1) + ir_span, ir_path = ir.permute(2, 0, 1).max(-1) + s_i.diagonal(w).copy_(ir_span + s_arc.diagonal(w)) + p_i.diagonal(w).copy_(ir_path + starts) + + # [n, w, batch_size] + slr = stripe(s_c, n, w) + stripe(s_c, n, w, (w, 1)) + slr_span, slr_path = slr.permute(2, 0, 1).max(-1) + # S(j, i) = max(C(i->r) + C(j->r+1)), i <= r < j + s_s.diagonal(-w).copy_(slr_span) + p_s.diagonal(-w).copy_(slr_path + starts) + # S(i, j) = max(C(i->r) + C(j->r+1)), i <= r < j + s_s.diagonal(w).copy_(slr_span) + p_s.diagonal(w).copy_(slr_path + starts) + + # C(j->i) = max(C(r->i) + I(j->r)), i <= r < j + cl = stripe(s_c, n, w, (0, 0), 0) + 
stripe(s_i, n, w, (w, 0)) + cl_span, cl_path = cl.permute(2, 0, 1).max(-1) + s_c.diagonal(-w).copy_(cl_span) + p_c.diagonal(-w).copy_(cl_path + starts) + # C(i->j) = max(I(i->r) + C(r->j)), i < r <= j + cr = stripe(s_i, n, w, (0, 1)) + stripe(s_c, n, w, (1, w), 0) + cr_span, cr_path = cr.permute(2, 0, 1).max(-1) + s_c.diagonal(w).copy_(cr_span) + # disable multi words to modify the root + s_c[0, w][lens.ne(w)] = float('-inf') + p_c.diagonal(w).copy_(cr_path + starts + 1) + + def backtrack(p_i, p_s, p_c, heads, i, j, flag): + if i == j: + return + if flag == 'c': + r = p_c[i, j] + backtrack(p_i, p_s, p_c, heads, i, r, 'i') + backtrack(p_i, p_s, p_c, heads, r, j, 'c') + elif flag == 's': + r = p_s[i, j] + i, j = sorted((i, j)) + backtrack(p_i, p_s, p_c, heads, i, r, 'c') + backtrack(p_i, p_s, p_c, heads, j, r + 1, 'c') + elif flag == 'i': + r, heads[j] = p_i[i, j], i + if r == i: + r = i + 1 if i < j else i - 1 + backtrack(p_i, p_s, p_c, heads, j, r, 'c') + else: + backtrack(p_i, p_s, p_c, heads, i, r, 'i') + backtrack(p_i, p_s, p_c, heads, r, j, 's') + + preds = [] + p_i = p_i.permute(2, 0, 1).cpu() + p_s = p_s.permute(2, 0, 1).cpu() + p_c = p_c.permute(2, 0, 1).cpu() + for i, length in enumerate(lens.tolist()): + heads = p_c.new_zeros(length + 1, dtype=torch.long) + backtrack(p_i[i], p_s[i], p_c[i], heads, 0, length, 'c') + preds.append(heads.to(mask.device)) + + return pad(preds, total_length=seq_len).to(mask.device) + + +def cky(scores, mask): + r""" + The implementation of `Cocke-Kasami-Younger`_ (CKY) algorithm to parse constituency trees. + + References: + - Yu Zhang, Houquan Zhou and Zhenghua Li. 2020. + `Fast and Accurate Neural CRF Constituency Parsing`_. + + Args: + scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. + Scores of all candidate constituents. + mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``. + The mask to avoid parsing over padding tokens. + For each square matrix in a batch, the positions except upper triangular part should be masked out. + + Returns: + Sequences of factorized predicted bracketed trees that are traversed in pre-order. + + Examples: + >>> scores = torch.tensor([[[ 2.5659, 1.4253, -2.5272, 3.3011], + [ 1.3687, -0.5869, 1.0011, 3.3020], + [ 1.2297, 0.4862, 1.1975, 2.5387], + [-0.0511, -1.2541, -0.7577, 0.2659]]]) + >>> mask = torch.tensor([[[False, True, True, True], + [False, False, True, True], + [False, False, False, True], + [False, False, False, False]]]) + >>> cky(scores, mask) + [[(0, 3), (0, 1), (1, 3), (1, 2), (2, 3)]] + + .. _Cocke-Kasami-Younger: + https://en.wikipedia.org/wiki/CYK_algorithm + .. 
_Fast and Accurate Neural CRF Constituency Parsing: + https://www.ijcai.org/Proceedings/2020/560/ + """ + + lens = mask[:, 0].sum(-1) + scores = scores.permute(1, 2, 0) + seq_len, seq_len, batch_size = scores.shape + s = scores.new_zeros(seq_len, seq_len, batch_size) + p = scores.new_zeros(seq_len, seq_len, batch_size).long() + + for w in range(1, seq_len): + n = seq_len - w + starts = p.new_tensor(range(n)).unsqueeze(0) + + if w == 1: + s.diagonal(w).copy_(scores.diagonal(w)) + continue + # [n, w, batch_size] + s_span = stripe(s, n, w - 1, (0, 1)) + stripe(s, n, w - 1, (1, w), 0) + # [batch_size, n, w] + s_span = s_span.permute(2, 0, 1) + # [batch_size, n] + s_span, p_span = s_span.max(-1) + s.diagonal(w).copy_(s_span + scores.diagonal(w)) + p.diagonal(w).copy_(p_span + starts + 1) + + def backtrack(p, i, j): + if j == i + 1: + return [(i, j)] + split = p[i][j] + ltree = backtrack(p, i, split) + rtree = backtrack(p, split, j) + return [(i, j)] + ltree + rtree + + p = p.permute(2, 0, 1).tolist() + trees = [backtrack(p[i], 0, length) for i, length in enumerate(lens.tolist())] + + return trees diff --git a/underthesea/utils/sp_common.py b/underthesea/utils/sp_common.py new file mode 100644 index 00000000..c0ff6e6b --- /dev/null +++ b/underthesea/utils/sp_common.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +pad = '' +unk = '' +bos = '' +eos = '' diff --git a/underthesea/utils/sp_config.py b/underthesea/utils/sp_config.py new file mode 100644 index 00000000..8734d5ab --- /dev/null +++ b/underthesea/utils/sp_config.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- + +from ast import literal_eval +from configparser import ConfigParser + + +class Config(object): + + def __init__(self, conf=None, **kwargs): + super(Config, self).__init__() + + config = ConfigParser() + config.read(conf or []) + self.update({**dict((name, literal_eval(value)) + for section in config.sections() + for name, value in config.items(section)), + **kwargs}) + + def __repr__(self): + s = line = "-" * 15 + "-+-" + "-" * 25 + "\n" + s += f"{'Param':15} | {'Value':^25}\n" + line + for name, value in vars(self).items(): + s += f"{name:15} | {str(value):^25}\n" + s += line + + return s + + def __getitem__(self, key): + return getattr(self, key) + + def __getstate__(self): + return vars(self) + + def __setstate__(self, state): + self.__dict__.update(state) + + def keys(self): + return vars(self).keys() + + def items(self): + return vars(self).items() + + def update(self, kwargs): + for key in ('self', 'cls', '__class__'): + kwargs.pop(key, None) + kwargs.update(kwargs.pop('kwargs', dict())) + for name, value in kwargs.items(): + setattr(self, name, value) + + return self + + def pop(self, key, val=None): + return self.__dict__.pop(key, val) diff --git a/underthesea/utils/sp_data.py b/underthesea/utils/sp_data.py new file mode 100644 index 00000000..58d5d8e7 --- /dev/null +++ b/underthesea/utils/sp_data.py @@ -0,0 +1,170 @@ +# -*- coding: utf-8 -*- + +import torch +import torch.distributed as dist + +from underthesea.utils.sp_alg import kmeans + + +class Dataset(torch.utils.data.Dataset): + r""" + Dataset that is compatible with :class:`torch.utils.data.Dataset`. + This serves as a wrapper for manipulating all data fields + with the operating behaviours defined in :class:`Transform`. + The data fields of all the instantiated sentences can be accessed as an attribute of the dataset. + + Args: + transform (Transform): + An instance of :class:`Transform` and its derivations. 
+ The instance holds a series of loading and processing behaviours with regard to the specfic data format. + data (list[list] or str): + A list of instances or a filename. + This will be passed into :meth:`transform.load`. + kwargs (dict): + Keyword arguments that will be passed into :meth:`transform.load` together with `data` + to control the loading behaviour. + + Attributes: + transform (Transform): + An instance of :class:`Transform`. + sentences (list[Sentence]): + A list of sentences loaded from the data. + Each sentence includes fields obeying the data format defined in ``transform``. + """ + + def __init__(self, transform, data, **kwargs): + super(Dataset, self).__init__() + + self.transform = transform + self.sentences = transform.load(data, **kwargs) + + def __repr__(self): + s = f"{self.__class__.__name__}(" + s += f"n_sentences={len(self.sentences)}" + if hasattr(self, 'loader'): + s += f", n_batches={len(self.loader)}" + if hasattr(self, 'buckets'): + s += f", n_buckets={len(self.buckets)}" + s += ")" + + return s + + def __len__(self): + return len(self.sentences) + + def __getitem__(self, index): + if not hasattr(self, 'fields'): + raise RuntimeError("The fields are not numericalized. Please build the dataset first.") + for d in self.fields.values(): + yield d[index] + + def __getattr__(self, name): + if name in self.__dict__: + return self.__dict__[name] + return [getattr(sentence, name) for sentence in self.sentences] + + def __setattr__(self, name, value): + if 'sentences' in self.__dict__ and name in self.sentences[0]: + # restore the order of sequences in the buckets + indices = torch.tensor([i + for bucket in self.buckets.values() + for i in bucket]).argsort() + for index, sentence in zip(indices, self.sentences): + setattr(sentence, name, value[index]) + else: + self.__dict__[name] = value + + def __getstate__(self): + # only pickle the Transform object and sentences + return {'transform': self.transform, 'sentences': self.sentences} + + def __setstate__(self, state): + self.__dict__.update(state) + + def collate_fn(self, batch): + return {f: d for f, d in zip(self.fields.keys(), zip(*batch))} + + def build(self, batch_size, n_buckets=1, shuffle=False, distributed=False): + # numericalize all fields + self.fields = self.transform(self.sentences) + # NOTE: the final bucket count is roughly equal to n_buckets + self.lengths = [len(i) for i in self.fields[next(iter(self.fields))]] + self.buckets = dict(zip(*kmeans(self.lengths, n_buckets))) + self.loader = DataLoader(dataset=self, + batch_sampler=Sampler(buckets=self.buckets, + batch_size=batch_size, + shuffle=shuffle, + distributed=distributed), + collate_fn=self.collate_fn) + + +class DataLoader(torch.utils.data.DataLoader): + r""" + DataLoader, matching with :class:`Dataset`. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __iter__(self): + for batch in super().__iter__(): + yield [f.compose(d) for f, d in batch.items()] + + +class Sampler(torch.utils.data.Sampler): + r""" + Sampler that supports for bucketization and token-level batchification. + + Args: + buckets (dict): + A dict that maps each centroid to indices of clustered sentences. + The centroid corresponds to the average length of all sentences in the bucket. + batch_size (int): + Token-level batch size. The resulting batch contains roughly the same number of tokens as ``batch_size``. + shuffle (bool): + If ``True``, the sampler will shuffle both buckets and samples in each bucket. Default: ``False``. 
+ distributed (bool): + If ``True``, the sampler will be used in conjunction with :class:`torch.nn.parallel.DistributedDataParallel` + that restricts data loading to a subset of the dataset. + Default: ``False``. + """ + + def __init__(self, buckets, batch_size, shuffle=False, distributed=False): + self.batch_size = batch_size + self.shuffle = shuffle + self.sizes, self.buckets = zip(*[(size, bucket) for size, bucket in buckets.items()]) + # number of chunks in each bucket, clipped by range [1, len(bucket)] + self.chunks = [min(len(bucket), max(round(size * len(bucket) / batch_size), 1)) + for size, bucket in zip(self.sizes, self.buckets)] + + self.rank = dist.get_rank() if distributed else 0 + self.replicas = dist.get_world_size() if distributed else 1 + self.samples = sum(self.chunks) // self.replicas + self.epoch = 0 + + def __iter__(self): + g = torch.Generator() + g.manual_seed(self.epoch) + range_fn = torch.arange + # if shuffle, shuffle both the buckets and samples in each bucket + # for distributed training, make sure each process generte the same random sequence at each epoch + if self.shuffle: + def range_fn(x): + return torch.randperm(x, generator=g) + total, count = 0, 0 + # TODO: more elegant way to deal with uneven data, which we directly discard right now + for i in range_fn(len(self.buckets)).tolist(): + split_sizes = [(len(self.buckets[i]) - j - 1) // self.chunks[i] + 1 + for j in range(self.chunks[i])] + # DON'T use `torch.chunk` which may return wrong number of chunks + for batch in range_fn(len(self.buckets[i])).split(split_sizes): + if count == self.samples: + break + if total % self.replicas == self.rank: + count += 1 + yield [self.buckets[i][j] for j in batch.tolist()] + total += 1 + self.epoch += 1 + + def __len__(self): + return self.samples diff --git a/underthesea/utils/sp_embedding.py b/underthesea/utils/sp_embedding.py new file mode 100644 index 00000000..37128670 --- /dev/null +++ b/underthesea/utils/sp_embedding.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- + +import torch + + +class Embedding(object): + + def __init__(self, tokens, vectors, unk=None): + self.tokens = tokens + self.vectors = torch.tensor(vectors) + self.pretrained = {w: v for w, v in zip(tokens, vectors)} + self.unk = unk + + def __len__(self): + return len(self.tokens) + + def __contains__(self, token): + return token in self.pretrained + + @property + def dim(self): + return self.vectors.size(1) + + @property + def unk_index(self): + if self.unk is not None: + return self.tokens.index(self.unk) + else: + raise AttributeError + + @classmethod + def load(cls, path, unk=None): + with open(path, 'r') as f: + lines = [line for line in f] + splits = [line.split() for line in lines] + tokens, vectors = zip(*[(s[0], list(map(float, s[1:]))) + for s in splits]) + + return cls(tokens, vectors, unk=unk) diff --git a/underthesea/utils/sp_field.py b/underthesea/utils/sp_field.py new file mode 100644 index 00000000..1b200cd0 --- /dev/null +++ b/underthesea/utils/sp_field.py @@ -0,0 +1,371 @@ +# -*- coding: utf-8 -*- + +from collections import Counter + +import torch + +from underthesea.utils.sp_fn import pad +from underthesea.utils.sp_vocab import Vocab + + +class RawField(object): + r""" + Defines a general datatype. + + A :class:`RawField` object does not assume any property of the datatype and + it holds parameters relating to how a datatype should be processed. + + Args: + name (str): + The name of the field. + fn (function): + The function used for preprocessing the examples. Default: ``None``. 
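+
+    Examples:
+        A hypothetical lower-casing field (the field name and ``fn`` below are made up
+        for this sketch):
+
+        >>> from underthesea.utils.sp_field import RawField
+        >>> field = RawField('text', fn=lambda seq: [token.lower() for token in seq])
+        >>> field.transform([['Xin', 'chào', 'Việt', 'Nam']])
+        [['xin', 'chào', 'việt', 'nam']]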
+ """ + + def __init__(self, name, fn=None): + self.name = name + self.fn = fn + + def __repr__(self): + return f"({self.name}): {self.__class__.__name__}()" + + def preprocess(self, sequence): + return self.fn(sequence) if self.fn is not None else sequence + + def transform(self, sequences): + return [self.preprocess(seq) for seq in sequences] + + def compose(self, sequences): + return sequences + + +class Field(RawField): + r""" + Defines a datatype together with instructions for converting to :class:`~torch.Tensor`. + :class:`Field` models common text processing datatypes that can be represented by tensors. + It holds a :class:`Vocab` object that defines the set of possible values + for elements of the field and their corresponding numerical representations. + The :class:`Field` object also holds other parameters relating to how a datatype + should be numericalized, such as a tokenization method. + + Args: + name (str): + The name of the field. + pad_token (str): + The string token used as padding. Default: ``None``. + unk_token (str): + The string token used to represent OOV words. Default: ``None``. + bos_token (str): + A token that will be prepended to every example using this field, or ``None`` for no `bos_token`. + Default: ``None``. + eos_token (str): + A token that will be appended to every example using this field, or ``None`` for no `eos_token`. + lower (bool): + Whether to lowercase the text in this field. Default: ``False``. + use_vocab (bool): + Whether to use a :class:`Vocab` object. If ``False``, the data in this field should already be numerical. + Default: ``True``. + tokenize (function): + The function used to tokenize strings using this field into sequential examples. Default: ``None``. + fn (function): + The function used for preprocessing the examples. Default: ``None``. + """ + + def __init__(self, name, pad=None, unk=None, bos=None, eos=None, + lower=False, use_vocab=True, tokenize=None, fn=None): + self.name = name + self.pad = pad + self.unk = unk + self.bos = bos + self.eos = eos + self.lower = lower + self.use_vocab = use_vocab + self.tokenize = tokenize + self.fn = fn + + self.specials = [token for token in [pad, unk, bos, eos] + if token is not None] + + def __repr__(self): + s, params = f"({self.name}): {self.__class__.__name__}(", [] + if self.pad is not None: + params.append(f"pad={self.pad}") + if self.unk is not None: + params.append(f"unk={self.unk}") + if self.bos is not None: + params.append(f"bos={self.bos}") + if self.eos is not None: + params.append(f"eos={self.eos}") + if self.lower: + params.append(f"lower={self.lower}") + if not self.use_vocab: + params.append(f"use_vocab={self.use_vocab}") + s += ", ".join(params) + s += ")" + + return s + + @property + def pad_index(self): + if self.pad is None: + return 0 + if hasattr(self, 'vocab'): + return self.vocab[self.pad] + return self.specials.index(self.pad) + + @property + def unk_index(self): + if self.unk is None: + return 0 + if hasattr(self, 'vocab'): + return self.vocab[self.unk] + return self.specials.index(self.unk) + + @property + def bos_index(self): + if hasattr(self, 'vocab'): + return self.vocab[self.bos] + return self.specials.index(self.bos) + + @property + def eos_index(self): + if hasattr(self, 'vocab'): + return self.vocab[self.eos] + return self.specials.index(self.eos) + + @property + def device(self): + return 'cuda' if torch.cuda.is_available() else 'cpu' + + def preprocess(self, sequence): + r""" + Loads a single example using this field, tokenizing if necessary. 
+ The sequence will be first passed to ``fn`` if available. + If ``tokenize`` is not None, the input will be tokenized. + Then the input will be lowercased optionally. + + Args: + sequence (list): + The sequence to be preprocessed. + + Returns: + A list of preprocessed sequence. + """ + + if self.fn is not None: + sequence = self.fn(sequence) + if self.tokenize is not None: + sequence = self.tokenize(sequence) + if self.lower: + sequence = [str.lower(token) for token in sequence] + + return sequence + + def build(self, dataset, min_freq=1, embed=None): + r""" + Constructs a :class:`Vocab` object for this field from the dataset. + If the vocabulary has already existed, this function will have no effect. + + Args: + dataset (Dataset): + A :class:`Dataset` object. One of the attributes should be named after the name of this field. + min_freq (int): + The minimum frequency needed to include a token in the vocabulary. Default: 1. + embed (Embedding): + An Embedding object, words in which will be extended to the vocabulary. Default: ``None``. + """ + + if hasattr(self, 'vocab'): + return + sequences = getattr(dataset, self.name) + counter = Counter(token + for seq in sequences + for token in self.preprocess(seq)) + self.vocab = Vocab(counter, min_freq, self.specials, self.unk_index) + + if not embed: + self.embed = None + else: + tokens = self.preprocess(embed.tokens) + # if the `unk` token has existed in the pretrained, + # then replace it with a self-defined one + if embed.unk: + tokens[embed.unk_index] = self.unk + + self.vocab.extend(tokens) + self.embed = torch.zeros(len(self.vocab), embed.dim) + self.embed[self.vocab[tokens]] = embed.vectors + self.embed /= torch.std(self.embed) + + def transform(self, sequences): + r""" + Turns a list of sequences that use this field into tensors. + + Each sequence is first preprocessed and then numericalized if needed. + + Args: + sequences (list[list[str]]): + A list of sequences. + + Returns: + A list of tensors transformed from the input sequences. + """ + + sequences = [self.preprocess(seq) for seq in sequences] + if self.use_vocab: + sequences = [self.vocab[seq] for seq in sequences] + if self.bos: + sequences = [[self.bos_index] + seq for seq in sequences] + if self.eos: + sequences = [seq + [self.eos_index] for seq in sequences] + sequences = [torch.tensor(seq) for seq in sequences] + + return sequences + + def compose(self, sequences): + r""" + Composes a batch of sequences into a padded tensor. + + Args: + sequences (list[~torch.Tensor]): + A list of tensors. + + Returns: + A padded tensor converted to proper device. + """ + + return pad(sequences, self.pad_index).to(self.device) + + +class SubwordField(Field): + r""" + A field that conducts tokenization and numericalization over each token rather the sequence. + + This is customized for models requiring character/subword-level inputs, e.g., CharLSTM and BERT. + + Args: + fix_len (int): + A fixed length that all subword pieces will be padded to. + This is used for truncating the subword pieces that exceed the length. + To save the memory, the final length will be the smaller value + between the max length of subword pieces in a batch and `fix_len`. 
+ + Examples: + >>> from transformers import AutoTokenizer + >>> tokenizer = AutoTokenizer.from_pretrained('bert-base-cased') + >>> field = SubwordField('bert', + pad=tokenizer.pad_token, + unk=tokenizer.unk_token, + bos=tokenizer.cls_token, + eos=tokenizer.sep_token, + fix_len=20, + tokenize=tokenizer.tokenize) + >>> field.vocab = tokenizer.get_vocab() # no need to re-build the vocab + >>> field.transform([['This', 'field', 'performs', 'token-level', 'tokenization']])[0] + tensor([[ 101, 0, 0], + [ 1188, 0, 0], + [ 1768, 0, 0], + [10383, 0, 0], + [22559, 118, 1634], + [22559, 2734, 0], + [ 102, 0, 0]]) + """ + + def __init__(self, *args, **kwargs): + self.fix_len = kwargs.pop('fix_len') if 'fix_len' in kwargs else 0 + super().__init__(*args, **kwargs) + + def build(self, dataset, min_freq=1, embed=None): + if hasattr(self, 'vocab'): + return + sequences = getattr(dataset, self.name) + counter = Counter(piece + for seq in sequences + for token in seq + for piece in self.preprocess(token)) + self.vocab = Vocab(counter, min_freq, self.specials, self.unk_index) + + if not embed: + self.embed = None + else: + tokens = self.preprocess(embed.tokens) + # if the `unk` token has existed in the pretrained, + # then replace it with a self-defined one + if embed.unk: + tokens[embed.unk_index] = self.unk + + self.vocab.extend(tokens) + self.embed = torch.zeros(len(self.vocab), embed.dim) + self.embed[self.vocab[tokens]] = embed.vectors + + def transform(self, sequences): + sequences = [[self.preprocess(token) for token in seq] + for seq in sequences] + if self.fix_len <= 0: + self.fix_len = max(len(token) for seq in sequences for token in seq) + if self.use_vocab: + sequences = [[[self.vocab[i] for i in token] if token else [self.unk_index] for token in seq] + for seq in sequences] + if self.bos: + sequences = [[[self.bos_index]] + seq for seq in sequences] + if self.eos: + sequences = [seq + [[self.eos_index]] for seq in sequences] + lens = [min(self.fix_len, max(len(ids) for ids in seq)) for seq in sequences] + sequences = [pad([torch.tensor(ids[:i]) for ids in seq], self.pad_index, i) + for i, seq in zip(lens, sequences)] + + return sequences + + +class ChartField(Field): + r""" + Field dealing with constituency trees. + + This field receives sequences of binarized trees factorized in pre-order, + and returns two tensors representing the bracketing trees and labels on each constituent respectively. 
+ + Examples: + >>> sequence = [(0, 5, 'S'), (0, 4, 'S|<>'), (0, 1, 'NP'), (1, 4, 'VP'), (1, 2, 'VP|<>'), + (2, 4, 'S+VP'), (2, 3, 'VP|<>'), (3, 4, 'NP'), (4, 5, 'S|<>')] + >>> spans, labels = field.transform([sequence])[0] # this example field is built from ptb + >>> spans + tensor([[False, True, False, False, True, True], + [False, False, True, False, True, False], + [False, False, False, True, True, False], + [False, False, False, False, True, False], + [False, False, False, False, False, True], + [False, False, False, False, False, False]]) + >>> labels + tensor([[ 0, 37, 0, 0, 107, 79], + [ 0, 0, 120, 0, 112, 0], + [ 0, 0, 0, 120, 86, 0], + [ 0, 0, 0, 0, 37, 0], + [ 0, 0, 0, 0, 0, 107], + [ 0, 0, 0, 0, 0, 0]]) + """ + + def build(self, dataset, min_freq=1): + counter = Counter(label + for seq in getattr(dataset, self.name) + for i, j, label in self.preprocess(seq)) + + self.vocab = Vocab(counter, min_freq, self.specials, self.unk_index) + + def transform(self, sequences): + sequences = [self.preprocess(seq) for seq in sequences] + spans, labels = [], [] + + for sequence in sequences: + seq_len = sequence[0][1] + 1 + span_chart = torch.full((seq_len, seq_len), self.pad_index, dtype=torch.bool) + label_chart = torch.full((seq_len, seq_len), self.pad_index, dtype=torch.long) + for i, j, label in sequence: + span_chart[i, j] = 1 + label_chart[i, j] = self.vocab[label] + spans.append(span_chart) + labels.append(label_chart) + + return list(zip(spans, labels)) + + def compose(self, sequences): + return [pad(i).to(self.device) for i in zip(*sequences)] diff --git a/underthesea/utils/sp_fn.py b/underthesea/utils/sp_fn.py new file mode 100644 index 00000000..6a678249 --- /dev/null +++ b/underthesea/utils/sp_fn.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- + +import unicodedata + + +def ispunct(token): + return all(unicodedata.category(char).startswith('P') + for char in token) + + +def isfullwidth(token): + return all(unicodedata.east_asian_width(char) in ['W', 'F', 'A'] + for char in token) + + +def islatin(token): + return all('LATIN' in unicodedata.name(char) + for char in token) + + +def isdigit(token): + return all('DIGIT' in unicodedata.name(char) + for char in token) + + +def tohalfwidth(token): + return unicodedata.normalize('NFKC', token) + + +def stripe(x, n, w, offset=(0, 0), dim=1): + r""" + Returns a diagonal stripe of the tensor. + + Args: + x (~torch.Tensor): the input tensor with 2 or more dims. + n (int): the length of the stripe. + w (int): the width of the stripe. + offset (tuple): the offset of the first two dims. + dim (int): 0 if returns a horizontal stripe; 1 otherwise. + + Returns: + a diagonal stripe of the tensor. 
+ + Examples: + >>> x = torch.arange(25).view(5, 5) + >>> x + tensor([[ 0, 1, 2, 3, 4], + [ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24]]) + >>> stripe(x, 2, 3, (1, 1)) + tensor([[ 6, 7, 8], + [12, 13, 14]]) + >>> stripe(x, 2, 3, dim=0) + tensor([[ 0, 5, 10], + [ 6, 11, 16]]) + """ + + x, seq_len = x.contiguous(), x.size(1) + stride, numel = list(x.stride()), x[0, 0].numel() + stride[0] = (seq_len + 1) * numel + stride[1] = (1 if dim == 1 else seq_len) * numel + return x.as_strided(size=(n, w, *x.shape[2:]), + stride=stride, + storage_offset=(offset[0] * seq_len + offset[1]) * numel) + + +def pad(tensors, padding_value=0, total_length=None): + size = [len(tensors)] + [max(tensor.size(i) for tensor in tensors) + for i in range(len(tensors[0].size()))] + if total_length is not None: + assert total_length >= size[1] + size[1] = total_length + out_tensor = tensors[0].data.new(*size).fill_(padding_value) + for i, tensor in enumerate(tensors): + out_tensor[i][[slice(0, i) for i in tensor.size()]] = tensor + return out_tensor diff --git a/underthesea/utils/sp_init.py b/underthesea/utils/sp_init.py new file mode 100644 index 00000000..d69a2033 --- /dev/null +++ b/underthesea/utils/sp_init.py @@ -0,0 +1,16 @@ +PRETRAINED = { + 'biaffine-dep-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.biaffine.dependency.char.zip', + 'biaffine-dep-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.biaffine.dependency.char.zip', + 'biaffine-dep-bert-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.biaffine.dependency.bert.zip', + 'biaffine-dep-bert-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.biaffine.dependency.bert.zip', + 'crfnp-dep-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.crfnp.dependency.char.zip', + 'crfnp-dep-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.crfnp.dependency.char.zip', + 'crf-dep-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.crf.dependency.char.zip', + 'crf-dep-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.crf.dependency.char.zip', + 'crf2o-dep-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.crf2o.dependency.char.zip', + 'crf2o-dep-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.crf2o.dependency.char.zip', + 'crf-con-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.crf.constituency.char.zip', + 'crf-con-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.crf.constituency.char.zip', + 'crf-con-bert-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.crf.constituency.bert.zip', + 'crf-con-bert-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.crf.constituency.bert.zip' +} diff --git a/underthesea/utils/sp_metric.py b/underthesea/utils/sp_metric.py new file mode 100644 index 00000000..92878400 --- /dev/null +++ b/underthesea/utils/sp_metric.py @@ -0,0 +1,187 @@ +# -*- coding: utf-8 -*- + +from collections import Counter + + +class Metric(object): + + def __lt__(self, other): + return self.score < other + + def __le__(self, other): + return self.score <= other + + def __ge__(self, other): + return self.score >= other + + def __gt__(self, other): + return self.score > other + + @property + def score(self): + return 0. 
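+
+# Illustrative note (not executed): the comparison operators above delegate to `score`,
+# which subclasses override (e.g. LAS for AttachmentMetric below). This is what lets the
+# trainer write `if dev_metric > best_metric: ...` to keep the best checkpoint, and it
+# also allows comparing a metric directly against a float threshold:
+#
+#     metric = AttachmentMetric()
+#     metric(arc_preds, rel_preds, arc_golds, rel_golds, mask)  # tensors from one batch
+#     if metric > 0.80:   # compares metric.score (LAS) with the float
+#         ...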
+ + +class AttachmentMetric(Metric): + + def __init__(self, eps=1e-8): + super().__init__() + + self.eps = eps + + self.n = 0.0 + self.n_ucm = 0.0 + self.n_lcm = 0.0 + self.total = 0.0 + self.correct_arcs = 0.0 + self.correct_rels = 0.0 + + def __repr__(self): + s = f"UCM: {self.ucm:6.2%} LCM: {self.lcm:6.2%} " + s += f"UAS: {self.uas:6.2%} LAS: {self.las:6.2%}" + return s + + def __call__(self, arc_preds, rel_preds, arc_golds, rel_golds, mask): + lens = mask.sum(1) + arc_mask = arc_preds.eq(arc_golds) & mask + rel_mask = rel_preds.eq(rel_golds) & arc_mask + arc_mask_seq, rel_mask_seq = arc_mask[mask], rel_mask[mask] + + self.n += len(mask) + self.n_ucm += arc_mask.sum(1).eq(lens).sum().item() + self.n_lcm += rel_mask.sum(1).eq(lens).sum().item() + + self.total += len(arc_mask_seq) + self.correct_arcs += arc_mask_seq.sum().item() + self.correct_rels += rel_mask_seq.sum().item() + + @property + def score(self): + return self.las + + @property + def ucm(self): + return self.n_ucm / (self.n + self.eps) + + @property + def lcm(self): + return self.n_lcm / (self.n + self.eps) + + @property + def uas(self): + return self.correct_arcs / (self.total + self.eps) + + @property + def las(self): + return self.correct_rels / (self.total + self.eps) + + +class BracketMetric(Metric): + + def __init__(self, eps=1e-8): + super().__init__() + + self.n = 0.0 + self.n_ucm = 0.0 + self.n_lcm = 0.0 + self.utp = 0.0 + self.ltp = 0.0 + self.pred = 0.0 + self.gold = 0.0 + self.eps = eps + + def __call__(self, preds, golds): + for pred, gold in zip(preds, golds): + upred = Counter([(i, j) for i, j, label in pred]) + ugold = Counter([(i, j) for i, j, label in gold]) + utp = list((upred & ugold).elements()) + lpred = Counter(pred) + lgold = Counter(gold) + ltp = list((lpred & lgold).elements()) + self.n += 1 + self.n_ucm += len(utp) == len(pred) == len(gold) + self.n_lcm += len(ltp) == len(pred) == len(gold) + self.utp += len(utp) + self.ltp += len(ltp) + self.pred += len(pred) + self.gold += len(gold) + + def __repr__(self): + s = f"UCM: {self.ucm:6.2%} LCM: {self.lcm:6.2%} " + s += f"UP: {self.up:6.2%} UR: {self.ur:6.2%} UF: {self.uf:6.2%} " + s += f"LP: {self.lp:6.2%} LR: {self.lr:6.2%} LF: {self.lf:6.2%}" + + return s + + @property + def score(self): + return self.lf + + @property + def ucm(self): + return self.n_ucm / (self.n + self.eps) + + @property + def lcm(self): + return self.n_lcm / (self.n + self.eps) + + @property + def up(self): + return self.utp / (self.pred + self.eps) + + @property + def ur(self): + return self.utp / (self.gold + self.eps) + + @property + def uf(self): + return 2 * self.utp / (self.pred + self.gold + self.eps) + + @property + def lp(self): + return self.ltp / (self.pred + self.eps) + + @property + def lr(self): + return self.ltp / (self.gold + self.eps) + + @property + def lf(self): + return 2 * self.ltp / (self.pred + self.gold + self.eps) + + +class SpanMetric(Metric): + + def __init__(self, eps=1e-5): + super(SpanMetric, self).__init__() + + self.tp = 0.0 + self.pred = 0.0 + self.gold = 0.0 + self.eps = eps + + def __call__(self, preds, golds): + for pred, gold in zip(preds, golds): + pred, gold = set(pred), set(gold) + self.tp += len(pred & gold) + self.pred += len(pred) + self.gold += len(gold) + + def __repr__(self): + return f"P: {self.p:6.2%} R: {self.r:6.2%} F: {self.f:6.2%}" + + @property + def score(self): + return self.f + + @property + def p(self): + return self.tp / (self.pred + self.eps) + + @property + def r(self): + return self.tp / (self.gold + self.eps) + + 
@property + def f(self): + return 2 * self.p * self.r / (self.p + self.r + self.eps) diff --git a/underthesea/utils/sp_parallel.py b/underthesea/utils/sp_parallel.py new file mode 100644 index 00000000..e7281eda --- /dev/null +++ b/underthesea/utils/sp_parallel.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- + +import os +from random import Random + +import torch +import torch.distributed as dist +import torch.nn as nn + + +class DistributedDataParallel(nn.parallel.DistributedDataParallel): + + def __init__(self, module, **kwargs): + super().__init__(module, **kwargs) + + def __getattr__(self, name): + wrapped = super().__getattr__('module') + if hasattr(wrapped, name): + return getattr(wrapped, name) + return super().__getattr__(name) + + +def init_device(device, backend='nccl', host=None, port=None): + os.environ['CUDA_VISIBLE_DEVICES'] = device + if torch.cuda.device_count() > 1: + host = host or os.environ.get('MASTER_ADDR', 'localhost') + port = port or os.environ.get('MASTER_PORT', str(Random(0).randint(10000, 20000))) + os.environ['MASTER_ADDR'] = host + os.environ['MASTER_PORT'] = port + dist.init_process_group(backend) + torch.cuda.set_device(dist.get_rank()) + + +def is_master(): + return not dist.is_initialized() or dist.get_rank() == 0 diff --git a/underthesea/utils/sp_vocab.py b/underthesea/utils/sp_vocab.py new file mode 100644 index 00000000..a814ad5b --- /dev/null +++ b/underthesea/utils/sp_vocab.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- + +from collections import defaultdict +from collections.abc import Iterable + + +class Vocab(object): + r""" + Defines a vocabulary object that will be used to numericalize a field. + + Args: + counter (~collections.Counter): + :class:`~collections.Counter` object holding the frequencies of each value found in the data. + min_freq (int): + The minimum frequency needed to include a token in the vocabulary. Default: 1. + specials (list[str]): + The list of special tokens (e.g., pad, unk, bos and eos) that will be prepended to the vocabulary. Default: []. + unk_index (int): + The index of unk token. Default: 0. + + Attributes: + itos: + A list of token strings indexed by their numerical identifiers. + stoi: + A :class:`~collections.defaultdict` object mapping token strings to numerical identifiers. + """ + + def __init__(self, counter, min_freq=1, specials=[], unk_index=0): + self.itos = list(specials) + self.stoi = defaultdict(lambda: unk_index) + self.stoi.update({token: i for i, token in enumerate(self.itos)}) + self.extend([token for token, freq in counter.items() + if freq >= min_freq]) + self.unk_index = unk_index + self.n_init = len(self) + + def __len__(self): + return len(self.itos) + + def __getitem__(self, key): + if isinstance(key, str): + return self.stoi[key] + elif not isinstance(key, Iterable): + return self.itos[key] + elif isinstance(key[0], str): + return [self.stoi[i] for i in key] + else: + return [self.itos[i] for i in key] + + def __contains__(self, token): + return token in self.stoi + + def __getstate__(self): + # avoid picking defaultdict + attrs = dict(self.__dict__) + # cast to regular dict + attrs['stoi'] = dict(self.stoi) + return attrs + + def __setstate__(self, state): + stoi = defaultdict(lambda: self.unk_index) + stoi.update(state['stoi']) + state['stoi'] = stoi + self.__dict__.update(state) + + def extend(self, tokens): + self.itos.extend(sorted(set(tokens).difference(self.stoi))) + self.stoi.update({token: i for i, token in enumerate(self.itos)})
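
[Editor's note — not part of the patch] To make the new vocabulary and metric classes concrete, here is a small usage sketch. The import paths match the files added in this diff; the Counter contents, spans, and threshold are made-up example data.

# Illustrative sketch of Vocab and SpanMetric from the modules added above.
from collections import Counter

from underthesea.utils.sp_metric import SpanMetric
from underthesea.utils.sp_vocab import Vocab

# Build a vocabulary: specials come first, tokens below min_freq are dropped,
# and unseen tokens fall back to unk_index.
counter = Counter({'NP': 5, 'VP': 3, 'S': 1})
vocab = Vocab(counter, min_freq=2, specials=['<pad>', '<unk>'], unk_index=1)

print(len(vocab))           # 4 -> ['<pad>', '<unk>', 'NP', 'VP']
print(vocab['NP'])          # 2
print(vocab['S'])           # 1 ('S' was filtered out, so it maps to unk)
print(vocab[['NP', 'VP']])  # [2, 3]
print(vocab[0])             # '<pad>'

# Metrics accumulate counts per call and compare directly against floats,
# which is how a training loop can track the best checkpoint.
metric = SpanMetric()
metric([[(0, 2, 'NP'), (2, 4, 'VP')]],   # predicted spans
       [[(0, 2, 'NP'), (1, 4, 'VP')]])   # gold spans
print(metric)        # P: 50.00% R: 50.00% F: 50.00%
print(metric > 0.4)  # True -- Metric.__gt__ compares metric.score (here, F1)
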