Commit f203c91: Initial commit

kaishengtai committed Apr 2, 2015
Showing 23 changed files with 2,710 additions and 0 deletions.
6 changes: 6 additions & 0 deletions .gitignore
@@ -0,0 +1,6 @@
.DS_Store
data
predictions
trained_models
*~
#*#
340 changes: 340 additions & 0 deletions LICENSE.txt

Large diffs are not rendered by default.

71 changes: 71 additions & 0 deletions README.md
@@ -0,0 +1,71 @@
Tree-Structured Long Short-Term Memory Networks
===============================================

An implementation of the Tree-LSTM architectures described in the paper
[Improved Semantic Representations From Tree-Structured Long Short-Term Memory
Networks](http://arxiv.org/abs/1503.00075) by Kai Sheng Tai, Richard Socher, and
Christopher Manning.

## Requirements

- [Torch7](https://github.com/torch/torch7)
- [penlight](https://github.com/stevedonovan/Penlight)
- [nn](https://github.com/torch/nn)
- [nngraph](https://github.com/torch/nngraph)
- [optim](https://github.com/torch/optim)
- Java >= 8 (for Stanford CoreNLP utilities)
- Python >= 2.7

The Torch/Lua dependencies can be installed using [luarocks](http://luarocks.org). For example:

```
luarocks install nngraph
```

## Usage

First run the following script:

```
./fetch_and_preprocess.sh
```

This downloads the following data:

- [SICK dataset](http://alt.qcri.org/semeval2014/task1/index.php?id=data-and-tools) (semantic relatedness task)
- [Stanford Sentiment Treebank](http://nlp.stanford.edu/sentiment/index.html) (sentiment classification task)
- [GloVe word vectors](http://nlp.stanford.edu/projects/glove/) (Common Crawl 840B) -- **Warning:** this is a 2GB download!

and the following libraries:

- [Stanford Parser](http://nlp.stanford.edu/software/lex-parser.shtml)
- [Stanford POS Tagger](http://nlp.stanford.edu/software/tagger.shtml)

The preprocessing script generates dependency parses of the SICK dataset using the
[Stanford Neural Network Dependency Parser](http://nlp.stanford.edu/software/nndep.shtml).

Alternatively, the download and preprocessing scripts can be called individually.
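For example, the individual steps performed by `fetch_and_preprocess.sh` are:

```
python2.7 scripts/download.py
python2.7 scripts/preprocess-sick.py
python2.7 scripts/preprocess-sst.py
```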

**For the semantic relatedness task, run:**

```
th relatedness/main.lua
```

**For the sentiment classification task, run:**

```
th sentiment/main.lua
```

This trains a model for the "fine-grained" 5-class classification sub-task.

For the binary classification sub-task, run:

```
th sentiment/main.lua --binary
```

Predictions are written to the `predictions` directory and trained model parameters are saved to the `trained_models` directory.

See the [paper](http://arxiv.org/abs/1503.00075) for details on these experiments.
12 changes: 12 additions & 0 deletions fetch_and_preprocess.sh
@@ -0,0 +1,12 @@
#!/bin/bash
python2.7 scripts/download.py
python2.7 scripts/preprocess-sick.py
python2.7 scripts/preprocess-sst.py

glove_dir="data/glove"
glove_pre="glove.840B"
glove_dim="300d"
if [ ! -f "$glove_dir/$glove_pre.$glove_dim.th" ]; then
  th scripts/convert-wordvecs.lua "$glove_dir/$glove_pre.$glove_dim.txt" \
    "$glove_dir/$glove_pre.vocab" "$glove_dir/$glove_pre.$glove_dim.th"
fi
42 changes: 42 additions & 0 deletions init.lua
@@ -0,0 +1,42 @@
require('torch')
require('nn')
require('nngraph')
require('optim')
require('xlua')
require('sys')

treelstm = {}

include('util/read_data.lua')
include('util/Tree.lua')
include('util/Vocab.lua')
include('layers/CRowAddTable.lua')
include('models/LSTM.lua')
include('models/TreeLSTM.lua')
include('models/ChildSumTreeLSTM.lua')
include('models/BinaryTreeLSTM.lua')
include('relatedness/TreeLSTMSim.lua')
include('sentiment/TreeLSTMSentiment.lua')

printf = utils.printf

-- global paths -- modify if desired
treelstm.data_dir = 'data'
treelstm.models_dir = 'trained_models'
treelstm.predictions_dir = 'predictions'

-- share parameters of nngraph gModule instances
function share_params(cell, src, ...)
  for i = 1, #cell.forwardnodes do
    local node = cell.forwardnodes[i]
    if node.data.module then
      node.data.module:share(src.forwardnodes[i].data.module, ...)
    end
  end
end

function header(s)
  print(string.rep('-', 80))
  print(s)
  print(string.rep('-', 80))
end
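As an illustrative sketch (not part of this commit), `share_params` can tie together two structurally identical `gModule` instances; the `make_cell` helper here is hypothetical:

```
require('nngraph')

-- hypothetical cell: y = ReLU(W * x + b)
local function make_cell()
  local input = nn.Identity()()
  local output = nn.ReLU()(nn.Linear(10, 10)(input))
  return nn.gModule({input}, {output})
end

local master = make_cell()
local clone = make_cell()

-- after this call, both cells read and update the same parameter storage
share_params(clone, master, 'weight', 'bias', 'gradWeight', 'gradBias')
```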
38 changes: 38 additions & 0 deletions layers/CRowAddTable.lua
@@ -0,0 +1,38 @@
--[[
Add a vector to every row of a matrix.
Input: { [n x m], [m] }
Output: [n x m]
--]]

local CRowAddTable, parent = torch.class('treelstm.CRowAddTable', 'nn.Module')

function CRowAddTable:__init()
  parent.__init(self)
  self.gradInput = {}
end

function CRowAddTable:updateOutput(input)
  self.output:resizeAs(input[1]):copy(input[1])
  for i = 1, self.output:size(1) do
    self.output[i]:add(input[2])
  end
  return self.output
end

function CRowAddTable:updateGradInput(input, gradOutput)
  self.gradInput[1] = self.gradInput[1] or input[1].new()
  self.gradInput[2] = self.gradInput[2] or input[2].new()
  self.gradInput[1]:resizeAs(input[1])
  self.gradInput[2]:resizeAs(input[2]):zero()

  self.gradInput[1]:copy(gradOutput)
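  -- the vector input was broadcast across rows in the forward pass,
  -- so its gradient accumulates gradOutput row by row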
  for i = 1, gradOutput:size(1) do
    self.gradInput[2]:add(gradOutput[i])
  end

  return self.gradInput
end
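A quick illustrative check (not part of this file), assuming `init.lua` has been loaded so that `treelstm.CRowAddTable` is defined:

```
local layer = treelstm.CRowAddTable()
local mat = torch.randn(4, 3)  -- [n x m]
local vec = torch.randn(3)     -- [m]
local out = layer:forward({mat, vec})
-- every row of (out - mat) should equal vec
print(out - mat)
```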
34 changes: 34 additions & 0 deletions lib/CollapseUnaryTransformer.java
@@ -0,0 +1,34 @@
import java.util.List;

import edu.stanford.nlp.ling.Label;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeTransformer;
import edu.stanford.nlp.util.Generics;

/**
 * This transformer collapses chains of unary nodes so that the top
 * node is the only node left. The Sentiment model does not handle
 * unary nodes, so this simplifies them to make a binary tree consist
 * entirely of binary nodes and preterminals. A new tree with new
 * nodes and labels is returned; the original tree is unchanged.
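 * For example, {@code (NP (NN dog))} collapses to {@code (NP dog)}:
 * the intermediate {@code NN} node is dropped and {@code NP} becomes
 * the preterminal.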
 *
 * @author John Bauer
 */
public class CollapseUnaryTransformer implements TreeTransformer {
  public Tree transformTree(Tree tree) {
    if (tree.isPreTerminal() || tree.isLeaf()) {
      return tree.deepCopy();
    }

    Label label = tree.label().labelFactory().newLabel(tree.label());
    Tree[] children = tree.children();
    while (children.length == 1 && !children[0].isLeaf()) {
      children = children[0].children();
    }
    List<Tree> processedChildren = Generics.newArrayList();
    for (Tree child : children) {
      processedChildren.add(transformTree(child));
    }
    return tree.treeFactory().newTreeNode(label, processedChildren);
  }
}
150 changes: 150 additions & 0 deletions lib/ConstituencyParse.java
@@ -0,0 +1,150 @@
import edu.stanford.nlp.process.WordTokenFactory;
import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.ling.Word;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.process.PTBTokenizer;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.parser.lexparser.TreeBinarizer;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.Trees;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.HashMap;
import java.util.Properties;
import java.util.Scanner;

public class ConstituencyParse {

  public static void main(String[] args) throws Exception {
    Properties props = StringUtils.argsToProperties(args);
    if (!props.containsKey("tokpath") ||
        !props.containsKey("parentpath")) {
      System.err.println(
        "usage: java ConstituencyParse [-tokenize] -tokpath <tokpath> -parentpath <parentpath>");
      System.exit(1);
    }

    boolean tokenize = false;
    if (props.containsKey("tokenize")) {
      tokenize = true;
    }

    String tokPath = props.getProperty("tokpath");
    String parentPath = props.getProperty("parentpath");

    BufferedWriter tokWriter = new BufferedWriter(new FileWriter(tokPath));
    BufferedWriter parentWriter = new BufferedWriter(new FileWriter(parentPath));

    LexicalizedParser parser = LexicalizedParser.loadModel(
      "edu/stanford/nlp/models/lexparser/englishPCFG.ser.gz");
    TreeBinarizer binarizer = TreeBinarizer.simpleTreeBinarizer(
      parser.getTLPParams().headFinder(), parser.treebankLanguagePack());
    CollapseUnaryTransformer transformer = new CollapseUnaryTransformer();

    Scanner stdin = new Scanner(System.in);
    int count = 0;
    long start = System.currentTimeMillis();
    while (stdin.hasNextLine()) {
      String line = stdin.nextLine();
      List<HasWord> tokens = new ArrayList<>();
      if (tokenize) {
        PTBTokenizer<Word> tokenizer = new PTBTokenizer<>(
          new StringReader(line), new WordTokenFactory(), "");
        while (tokenizer.hasNext()) {
          tokens.add(tokenizer.next());
        }
      } else {
        for (String word : line.split(" ")) {
          tokens.add(new Word(word));
        }
      }

      Tree tree = parser.apply(tokens);
      Tree binarized = binarizer.transformTree(tree);
      Tree collapsedUnary = transformer.transformTree(binarized);
      Trees.convertToCoreLabels(collapsedUnary);
      collapsedUnary.indexSpans();

      List<Tree> leaves = collapsedUnary.getLeaves();
      int size = collapsedUnary.size() - leaves.size();
      int[] parents = new int[size];
      HashMap<Integer, Integer> index = new HashMap<>();

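      // Node numbering for the parent-pointer output: non-leaf nodes are
      // numbered 1..size, with preterminals taking 1..L in word order and
      // internal nodes numbered from L+1 upward as they are first reached.
      // parents[i] is the 1-indexed number of node (i+1)'s parent; 0 marks
      // the root.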
      int idx = leaves.size();
      int leafIdx = 0;
      for (Tree leaf : leaves) {
        Tree cur = leaf.parent(collapsedUnary); // go to preterminal
        int curIdx = leafIdx++;
        boolean done = false;
        while (!done) {
          Tree parent = cur.parent(collapsedUnary);
          if (parent == null) {
            parents[curIdx] = 0;
            break;
          }

          int parentIdx;
          int parentNumber = parent.nodeNumber(collapsedUnary);
          if (!index.containsKey(parentNumber)) {
            parentIdx = idx++;
            index.put(parentNumber, parentIdx);
          } else {
            parentIdx = index.get(parentNumber);
            done = true;
          }

          parents[curIdx] = parentIdx + 1;
          cur = parent;
          curIdx = parentIdx;
        }
      }

      // print tokens
      int len = tokens.size();
      StringBuilder sb = new StringBuilder();
      for (int i = 0; i < len - 1; i++) {
        if (tokenize) {
          sb.append(PTBTokenizer.ptbToken2Text(tokens.get(i).word()));
        } else {
          sb.append(tokens.get(i).word());
        }
        sb.append(' ');
      }
      if (tokenize) {
        sb.append(PTBTokenizer.ptbToken2Text(tokens.get(len - 1).word()));
      } else {
        sb.append(tokens.get(len - 1).word());
      }
      sb.append('\n');
      tokWriter.write(sb.toString());

      // print parent pointers
      sb = new StringBuilder();
      for (int i = 0; i < size - 1; i++) {
        sb.append(parents[i]);
        sb.append(' ');
      }
      sb.append(parents[size - 1]);
      sb.append('\n');
      parentWriter.write(sb.toString());

      count++;
      if (count % 1000 == 0) {
        double elapsed = (System.currentTimeMillis() - start) / 1000.0;
        System.err.printf("Parsed %d lines (%.2fs)\n", count, elapsed);
      }
    }

    long totalTimeMillis = System.currentTimeMillis() - start;
    System.err.printf("Done: %d lines in %.2fs (%.1fms per line)\n",
      count, totalTimeMillis / 1000.0, totalTimeMillis / (double) count);
    tokWriter.close();
    parentWriter.close();
  }
}
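Each input line yields one line of space-separated tokens in the `tokpath` file and one line of parent pointers in the `parentpath` file. As an illustrative sketch (not part of this commit; the repository's own reader lives in `util/read_data.lua`), a parent-pointer line such as `3 3 0` — two preterminals whose common parent, node 3, is the root — could be decoded in Lua:

```
-- Hypothetical reader for one line of parent pointers.
-- Nodes are 1-indexed; a parent of 0 marks the root.
local function read_parents(line)
  local parents, children, root = {}, {}, nil
  for p in line:gmatch('%d+') do
    parents[#parents + 1] = tonumber(p)
  end
  for i, p in ipairs(parents) do
    if p == 0 then
      root = i
    else
      children[p] = children[p] or {}
      table.insert(children[p], i)
    end
  end
  return parents, children, root
end

local parents, children, root = read_parents('3 3 0')
-- root == 3 and children[3] == {1, 2}
```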