-
Notifications
You must be signed in to change notification settings - Fork 233
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit f203c91
Showing
23 changed files
with
2,710 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
.DS_Store
data
predictions
trained_models
*~
#*#
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
Tree-Structured Long Short-Term Memory Networks
===============================================

An implementation of the Tree-LSTM architectures described in the paper
[Improved Semantic Representations From Tree-Structured Long Short-Term Memory
Networks](http://arxiv.org/abs/1503.00075) by Kai Sheng Tai, Richard Socher, and
Christopher Manning.

## Requirements

- [Torch7](https://github.com/torch/torch7)
- [penlight](https://github.com/stevedonovan/Penlight)
- [nn](https://github.com/torch/nn)
- [nngraph](https://github.com/torch/nngraph)
- [optim](https://github.com/torch/optim)
- Java >= 8 (for Stanford CoreNLP utilities)
- Python >= 2.7

The Torch/Lua dependencies can be installed using [luarocks](http://luarocks.org). For example:

```
luarocks install nngraph
```

## Usage

First run the following script:

```
./fetch_and_preprocess.sh
```

This downloads the following data:

- [SICK dataset](http://alt.qcri.org/semeval2014/task1/index.php?id=data-and-tools) (semantic relatedness task)
- [Stanford Sentiment Treebank](http://nlp.stanford.edu/sentiment/index.html) (sentiment classification task)
- [Glove word vectors](http://nlp.stanford.edu/projects/glove/) (Common Crawl 840B) -- **Warning:** this is a 2GB download!

and the following libraries:

- [Stanford Parser](http://nlp.stanford.edu/software/lex-parser.shtml)
- [Stanford POS Tagger](http://nlp.stanford.edu/software/tagger.shtml)

The preprocessing script generates dependency parses of the SICK dataset using the
[Stanford Neural Network Dependency Parser](http://nlp.stanford.edu/software/nndep.shtml).

Alternatively, the download and preprocessing scripts can be called individually.

**For the semantic relatedness task, run:**

```
th relatedness/main.lua
```

**For the sentiment classification task, run:**

```
th sentiment/main.lua
```

This trains a model for the "fine-grained" 5-class classification sub-task.

For the binary classification sub-task, run:

```
th sentiment/main.lua --binary
```

Predictions are written to the `predictions` directory and trained model parameters are saved to the `trained_models` directory.

See the [paper](http://arxiv.org/abs/1503.00075) for details on these experiments.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
#!/bin/bash
# Download the datasets/word vectors and run preprocessing for the Tree-LSTM
# experiments (SICK semantic relatedness and Stanford Sentiment Treebank).
python2.7 scripts/download.py
python2.7 scripts/preprocess-sick.py
python2.7 scripts/preprocess-sst.py

# Convert the Glove text vectors to Torch serialized format (done only once:
# skipped if the .th output already exists).
glove_dir="data/glove"
glove_pre="glove.840B"
glove_dim="300d"
# Quote all expansions so paths never undergo word splitting (shellcheck SC2086).
if [ ! -f "$glove_dir/$glove_pre.$glove_dim.th" ]; then
  th scripts/convert-wordvecs.lua "$glove_dir/$glove_pre.$glove_dim.txt" \
    "$glove_dir/$glove_pre.vocab" "$glove_dir/$glove_pre.$glove_dim.th"
fi
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
-- Package entry point: load Torch dependencies, pull in all treelstm source
-- files, and define the globals/paths shared across the project.
require('torch')
require('nn')
require('nngraph')
require('optim')
require('xlua')
require('sys')

-- Global namespace table for the package (pre-Lua-5.2 module style).
treelstm = {}

-- Order matters: base utilities and layers first, then models that build on
-- them, then the task-specific wrappers.
include('util/read_data.lua')
include('util/Tree.lua')
include('util/Vocab.lua')
include('layers/CRowAddTable.lua')
include('models/LSTM.lua')
include('models/TreeLSTM.lua')
include('models/ChildSumTreeLSTM.lua')
include('models/BinaryTreeLSTM.lua')
include('relatedness/TreeLSTMSim.lua')
include('sentiment/TreeLSTMSentiment.lua')

-- NOTE(review): `utils` is not defined in this file; it is presumably set up
-- globally by one of the libraries required above — confirm.
printf = utils.printf

-- global paths -- modify if desired
treelstm.data_dir = 'data'
treelstm.models_dir = 'trained_models'
treelstm.predictions_dir = 'predictions'
||
-- Make every module in the nngraph gModule `cell` share parameter tensors
-- with the module at the same node position in `src`; the vararg list is
-- forwarded unchanged to each module's :share().
function share_params(cell, src, ...)
  local dst_nodes = cell.forwardnodes
  local src_nodes = src.forwardnodes
  for i = 1, #dst_nodes do
    local mod = dst_nodes[i].data.module
    if mod then
      mod:share(src_nodes[i].data.module, ...)
    end
  end
end
|
||
-- Print `s` framed above and below by an 80-character horizontal rule.
function header(s)
  local rule = string.rep('-', 80)
  print(rule)
  print(s)
  print(rule)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
--[[
Add a vector to every row of a matrix.
Input: { [n x m], [m] }
Output: [n x m]
--]]

local CRowAddTable, parent = torch.class('treelstm.CRowAddTable', 'nn.Module')

-- Constructor: gradInput is a table because the layer takes two inputs.
function CRowAddTable:__init()
  parent.__init(self)
  self.gradInput = {}
end

-- Forward pass: output[i] = input[1][i] + input[2] for every row i.
function CRowAddTable:updateOutput(input)
  self.output:resizeAs(input[1]):copy(input[1])
  for i = 1, self.output:size(1) do
    self.output[i]:add(input[2])
  end
  return self.output
end

-- Backward pass: the matrix input receives gradOutput unchanged; the vector
-- input receives gradOutput summed over rows.
function CRowAddTable:updateGradInput(input, gradOutput)
  -- Lazily allocate gradient buffers with the same tensor type as the inputs
  -- and reuse them across calls to avoid reallocation.
  self.gradInput[1] = self.gradInput[1] or input[1].new()
  self.gradInput[2] = self.gradInput[2] or input[2].new()
  self.gradInput[1]:resizeAs(input[1])
  self.gradInput[2]:resizeAs(input[2]):zero()

  self.gradInput[1]:copy(gradOutput)
  for i = 1, gradOutput:size(1) do
    self.gradInput[2]:add(gradOutput[i])
  end

  return self.gradInput
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import java.util.List; | ||
|
||
import edu.stanford.nlp.ling.Label; | ||
import edu.stanford.nlp.trees.Tree; | ||
import edu.stanford.nlp.trees.TreeTransformer; | ||
import edu.stanford.nlp.util.Generics; | ||
|
||
/** | ||
* This transformer collapses chains of unary nodes so that the top | ||
* node is the only node left. The Sentiment model does not handle | ||
* unary nodes, so this simplifies them to make a binary tree consist | ||
* entirely of binary nodes and preterminals. A new tree with new | ||
* nodes and labels is returned; the original tree is unchanged. | ||
* | ||
* @author John Bauer | ||
*/ | ||
public class CollapseUnaryTransformer implements TreeTransformer { | ||
public Tree transformTree(Tree tree) { | ||
if (tree.isPreTerminal() || tree.isLeaf()) { | ||
return tree.deepCopy(); | ||
} | ||
|
||
Label label = tree.label().labelFactory().newLabel(tree.label()); | ||
Tree[] children = tree.children(); | ||
while (children.length == 1 && !children[0].isLeaf()) { | ||
children = children[0].children(); | ||
} | ||
List<Tree> processedChildren = Generics.newArrayList(); | ||
for (Tree child : children) { | ||
processedChildren.add(transformTree(child)); | ||
} | ||
return tree.treeFactory().newTreeNode(label, processedChildren); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
import edu.stanford.nlp.process.WordTokenFactory; | ||
import edu.stanford.nlp.ling.HasWord; | ||
import edu.stanford.nlp.ling.Word; | ||
import edu.stanford.nlp.ling.CoreLabel; | ||
import edu.stanford.nlp.process.PTBTokenizer; | ||
import edu.stanford.nlp.util.StringUtils; | ||
import edu.stanford.nlp.parser.lexparser.LexicalizedParser; | ||
import edu.stanford.nlp.parser.lexparser.TreeBinarizer; | ||
import edu.stanford.nlp.trees.Tree; | ||
import edu.stanford.nlp.trees.Trees; | ||
|
||
import java.io.BufferedWriter; | ||
import java.io.FileWriter; | ||
import java.io.StringReader; | ||
import java.util.ArrayList; | ||
import java.util.Collection; | ||
import java.util.List; | ||
import java.util.HashMap; | ||
import java.util.Properties; | ||
import java.util.Scanner; | ||
|
||
/**
 * Command-line utility: reads one sentence per line from stdin, parses each
 * with the Stanford lexicalized PCFG parser, binarizes the tree and collapses
 * unary chains, then writes two line-aligned outputs: the tokenized sentence
 * to -tokpath and a parent-pointer encoding of the tree to -parentpath.
 */
public class ConstituencyParse {

  public static void main(String[] args) throws Exception {
    Properties props = StringUtils.argsToProperties(args);
    if (!props.containsKey("tokpath") ||
        !props.containsKey("parentpath")) {
      System.err.println(
        "usage: java ConstituencyParse -tokenize - -tokpath <tokpath> -parentpath <parentpath>");
      System.exit(1);
    }

    // Presence of -tokenize enables PTB tokenization; otherwise input lines
    // are assumed pre-tokenized and are split on single spaces.
    boolean tokenize = false;
    if (props.containsKey("tokenize")) {
      tokenize = true;
    }

    String tokPath = props.getProperty("tokpath");
    String parentPath = props.getProperty("parentpath");

    BufferedWriter tokWriter = new BufferedWriter(new FileWriter(tokPath));
    BufferedWriter parentWriter = new BufferedWriter(new FileWriter(parentPath));

    LexicalizedParser parser = LexicalizedParser.loadModel(
      "edu/stanford/nlp/models/lexparser/englishPCFG.ser.gz");
    TreeBinarizer binarizer = TreeBinarizer.simpleTreeBinarizer(
      parser.getTLPParams().headFinder(), parser.treebankLanguagePack());
    CollapseUnaryTransformer transformer = new CollapseUnaryTransformer();

    Scanner stdin = new Scanner(System.in);
    int count = 0;
    long start = System.currentTimeMillis();
    while (stdin.hasNextLine()) {
      String line = stdin.nextLine();
      // NOTE(review): an empty input line leaves `tokens` empty, and the
      // token-printing code below calls tokens.get(len - 1) — this assumes
      // every input line is non-empty; confirm with callers.
      List<HasWord> tokens = new ArrayList<>();
      if (tokenize) {
        PTBTokenizer<Word> tokenizer = new PTBTokenizer(
          new StringReader(line), new WordTokenFactory(), "");
        for (Word label; tokenizer.hasNext(); ) {
          tokens.add(tokenizer.next());
        }
      } else {
        for (String word : line.split(" ")) {
          tokens.add(new Word(word));
        }
      }

      Tree tree = parser.apply(tokens);
      Tree binarized = binarizer.transformTree(tree);
      Tree collapsedUnary = transformer.transformTree(binarized);
      Trees.convertToCoreLabels(collapsedUnary);
      collapsedUnary.indexSpans();

      // Build a parent-pointer array over the non-leaf nodes: parents[i] is
      // the 1-based index of node i's parent, with 0 marking the root.
      // Preterminals occupy indices 0..leaves.size()-1 (one per leaf, via
      // leafIdx); higher internal nodes are assigned indices in order of
      // first discovery (via idx), keyed by their tree node number.
      List<Tree> leaves = collapsedUnary.getLeaves();
      int size = collapsedUnary.size() - leaves.size();
      int[] parents = new int[size];
      HashMap<Integer, Integer> index = new HashMap<Integer, Integer>();

      int idx = leaves.size();
      int leafIdx = 0;
      for (Tree leaf : leaves) {
        Tree cur = leaf.parent(collapsedUnary); // go to preterminal
        int curIdx = leafIdx++;
        boolean done = false;
        while (!done) {
          Tree parent = cur.parent(collapsedUnary);
          if (parent == null) {
            parents[curIdx] = 0;
            break;
          }

          int parentIdx;
          int parentNumber = parent.nodeNumber(collapsedUnary);
          if (!index.containsKey(parentNumber)) {
            parentIdx = idx++;
            index.put(parentNumber, parentIdx);
          } else {
            // Ancestor already visited from an earlier leaf: record this
            // edge and stop climbing (the path above is already encoded).
            parentIdx = index.get(parentNumber);
            done = true;
          }

          parents[curIdx] = parentIdx + 1;
          cur = parent;
          curIdx = parentIdx;
        }
      }

      // print tokens
      int len = tokens.size();
      StringBuilder sb = new StringBuilder();
      for (int i = 0; i < len - 1; i++) {
        if (tokenize) {
          sb.append(PTBTokenizer.ptbToken2Text(tokens.get(i).word()));
        } else {
          sb.append(tokens.get(i).word());
        }
        sb.append(' ');
      }
      if (tokenize) {
        sb.append(PTBTokenizer.ptbToken2Text(tokens.get(len - 1).word()));
      } else {
        sb.append(tokens.get(len - 1).word());
      }
      sb.append('\n');
      tokWriter.write(sb.toString());

      // print parent pointers
      sb = new StringBuilder();
      for (int i = 0; i < size - 1; i++) {
        sb.append(parents[i]);
        sb.append(' ');
      }
      sb.append(parents[size - 1]);
      sb.append('\n');
      parentWriter.write(sb.toString());

      count++;
      if (count % 1000 == 0) {
        double elapsed = (System.currentTimeMillis() - start) / 1000.0;
        System.err.printf("Parsed %d lines (%.2fs)\n", count, elapsed);
      }
    }

    // NOTE(review): with zero input lines this divides by count == 0
    // (prints NaN/Infinity) — harmless but worth confirming.
    long totalTimeMillis = System.currentTimeMillis() - start;
    System.err.printf("Done: %d lines in %.2fs (%.1fms per line)\n",
      count, totalTimeMillis / 1000.0, totalTimeMillis / (double) count);
    tokWriter.close();
    parentWriter.close();
  }
}
Oops, something went wrong.