Linear Chain CRF layer and a text chunking example #4621

Closed · wants to merge 15 commits
132 changes: 132 additions & 0 deletions examples/conll2000_bi_lstm_crf.py
@@ -0,0 +1,132 @@
'''This example demonstrates the use of a bidirectional LSTM and a
linear-chain conditional random field for text chunking
(http://www.cnts.ua.ac.be/conll2000/chunking/).

Gets an FB1 score >= 93 on the test dataset after 3 epochs.
'''
from __future__ import print_function, unicode_literals
import numpy as np
np.random.seed(1337) # for reproducibility

from six.moves import zip
from keras.preprocessing import sequence
from keras.models import Model
from keras.layers import Input, merge
from keras.layers import Dense, Embedding, ChainCRF, LSTM, Bidirectional, Dropout
from keras.layers.wrappers import TimeDistributed
from keras.optimizers import RMSprop
from keras.datasets import conll2000
from keras.utils.data_utils import get_file
from keras.callbacks import Callback
from subprocess import Popen, PIPE, STDOUT


def run_conlleval(X_words_test, y_test, y_pred, index2word, index2chunk, pad_id=0):
    '''Runs the conlleval script for evaluating the predicted IOB tags.
    '''
    url = 'http://www.cnts.ua.ac.be/conll2000/chunking/conlleval.txt'
    path = get_file('conlleval',
                    origin=url,
                    md5_hash='61b632189e5a05d5bd26a2e1ec0f4f9e')

    p = Popen(['perl', path], stdout=PIPE, stdin=PIPE, stderr=STDOUT)

    y_true = np.squeeze(y_test, axis=2)

    # Index of the first padding token gives the sentence length
    # (assumes every sequence contains at least one padding token).
    sequence_lengths = np.argmax(X_words_test == pad_id, axis=1)
    nb_samples = X_words_test.shape[0]
    conlleval_input = []
    for k in range(nb_samples):
        sent_len = sequence_lengths[k]
        words = list(map(lambda idx: index2word[idx], X_words_test[k][:sent_len]))
        true_tags = list(map(lambda idx: index2chunk[idx], y_true[k][:sent_len]))
        pred_tags = list(map(lambda idx: index2chunk[idx], y_pred[k][:sent_len]))
        sent = zip(words, true_tags, pred_tags)
        for row in sent:
            conlleval_input.append(' '.join(row))
        conlleval_input.append('')
    print()
    conlleval_stdout = p.communicate(input='\n'.join(conlleval_input).encode())[0]
    print(conlleval_stdout.decode())
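For reference, conlleval reads one token per line in `word gold-tag predicted-tag` order, with an empty line marking each sentence boundary, which is exactly the block the loop above assembles. A hand-made illustration (the sentence and tags below are invented, not taken from the dataset):

# Hypothetical conlleval input for a single four-word sentence.
# Column order: word, true chunk tag, predicted chunk tag.
example_block = '\n'.join([
    'the B-NP B-NP',
    'deficit I-NP I-NP',
    'narrowed B-VP B-VP',
    '. O O',
    '',  # empty line terminates the sentence
])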


class ConllevalCallback(Callback):
    '''Callback for running the conlleval script on the test dataset after
    each epoch.
    '''
    def __init__(self, X_test, y_test, batch_size=1, index2word=None, index2chunk=None):
        self.X_words_test, self.X_pos_test = X_test
        self.y_test = y_test
        self.batch_size = batch_size
        self.index2word = index2word
        self.index2chunk = index2chunk

    def on_epoch_end(self, epoch, logs={}):
        X_test = [self.X_words_test, self.X_pos_test]
        pred_proba = self.model.predict(X_test, batch_size=self.batch_size)
        y_pred = np.argmax(pred_proba, axis=2)
        run_conlleval(self.X_words_test, self.y_test, y_pred, self.index2word, self.index2chunk)


maxlen = 80  # cut sentences after this number of words
word_embedding_dim = 100
pos_embedding_dim = 32
lstm_dim = 100
batch_size = 64

print('Loading data...')
(X_words_train, X_pos_train, y_train), (X_words_test, X_pos_test, y_test), (index2word, index2pos, index2chunk) = conll2000.load_data(word_preprocess=lambda w: w.lower())

max_features = len(index2word)
nb_pos_tags = len(index2pos)
nb_chunk_tags = len(index2chunk)

X_words_train = sequence.pad_sequences(X_words_train, maxlen=maxlen, padding='post')
X_pos_train = sequence.pad_sequences(X_pos_train, maxlen=maxlen, padding='post')
X_words_test = sequence.pad_sequences(X_words_test, maxlen=maxlen, padding='post')
X_pos_test = sequence.pad_sequences(X_pos_test, maxlen=maxlen, padding='post')
y_train = sequence.pad_sequences(y_train, maxlen=maxlen, padding='post')
y_train = np.expand_dims(y_train, -1)
y_test = sequence.pad_sequences(y_test, maxlen=maxlen, padding='post')
y_test = np.expand_dims(y_test, -1)

print('Unique words:', max_features)
print('Unique pos_tags:', nb_pos_tags)
print('Unique chunk tags:', nb_chunk_tags)
print('X_words_train shape:', X_words_train.shape)
print('X_words_test shape:', X_words_test.shape)
print('y_train shape:', y_train.shape)
print('y_test shape:', y_test.shape)

print('Build model...')

word_input = Input(shape=(maxlen,), dtype='int32', name='word_input')
word_emb = Embedding(max_features, word_embedding_dim, input_length=maxlen, dropout=0.2, name='word_emb')(word_input)
pos_input = Input(shape=(maxlen,), dtype='int32', name='pos_input')
pos_emb = Embedding(nb_pos_tags, pos_embedding_dim, input_length=maxlen, dropout=0.2, name='pos_emb')(pos_input)
total_emb = merge([word_emb, pos_emb], mode='concat', concat_axis=2)

bilstm = Bidirectional(LSTM(lstm_dim, dropout_W=0.2, dropout_U=0.2, return_sequences=True))(total_emb)
bilstm_d = Dropout(0.2)(bilstm)
dense = TimeDistributed(Dense(nb_chunk_tags))(bilstm_d)

crf = ChainCRF()
crf_output = crf(dense)

model = Model(input=[word_input, pos_input], output=[crf_output])

model.compile(loss=crf.sparse_loss,
              optimizer=RMSprop(0.01),
              metrics=['sparse_categorical_accuracy'])

model.summary()


conlleval = ConllevalCallback([X_words_test, X_pos_test], y_test,
                              index2word=index2word, index2chunk=index2chunk,
                              batch_size=batch_size)
print('Train...')
model.fit([X_words_train, X_pos_train], y_train,
          batch_size=batch_size, nb_epoch=3, callbacks=[conlleval])
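Not part of the patch, but a possible follow-up once training finishes: decoding chunk tags for a new sentence. This sketch assumes the script above has already run, that POS tags for the new sentence are given, and it rebuilds the lookup tables from the index lists returned by `conll2000.load_data`:

# Hypothetical usage sketch; names below reuse variables from the script above.
word2index = {w: i for i, w in enumerate(index2word)}
pos2index = {p: i for i, p in enumerate(index2pos)}

words = ['the', 'deficit', 'narrowed', '.']
pos_tags = ['DT', 'NN', 'VBD', '.']

x_words = sequence.pad_sequences(
    [[word2index.get(w, word2index['<UNK>']) for w in words]],
    maxlen=maxlen, padding='post')
x_pos = sequence.pad_sequences(
    [[pos2index[p] for p in pos_tags]],
    maxlen=maxlen, padding='post')

pred = model.predict([x_words, x_pos])       # shape (1, maxlen, nb_chunk_tags)
tag_ids = np.argmax(pred, axis=2)[0][:len(words)]
print([index2chunk[i] for i in tag_ids])     # e.g. ['B-NP', 'I-NP', 'B-VP', 'O']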
26 changes: 26 additions & 0 deletions keras/backend/tensorflow_backend.py
@@ -3089,6 +3089,32 @@ def ctc_decode(y_pred, input_length, greedy=True, beam_width=100,
    return (decoded_dense, log_prob)


def logsumexp(x, axis=None):
    '''Returns `log(sum(exp(x), axis=axis))` with improved numerical stability.
    '''
    if axis is None:
        # Reduce over all dimensions when no axis is given.
        return tf.reduce_logsumexp(x)
    return tf.reduce_logsumexp(x, axis=[axis])
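The trick both backend implementations rely on, shown with plain NumPy (illustration only, not part of the patch): shifting by the per-axis maximum keeps `exp` from overflowing, and the maximum is added back outside the logarithm.

import numpy as np

def np_logsumexp(x, axis=None):
    # Subtract the max before exponentiating so exp() cannot overflow,
    # then add it back outside the log.
    xmax = np.max(x, axis=axis, keepdims=True)
    return np.max(x, axis=axis) + np.log(np.sum(np.exp(x - xmax), axis=axis))

x = np.array([1000.0, 1001.0, 1002.0])
print(np.log(np.sum(np.exp(x))))   # inf: exp(1000) overflows in float64
print(np_logsumexp(x))             # ~1002.41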


def batch_gather(reference, indices):
    '''Batchwise gathering of row indices.

    The numpy equivalent is `reference[np.arange(batch_size), indices]`.

    # Arguments
        reference: tensor with ndim >= 2 of shape
            (batch_size, dim1, dim2, ..., dimN)
        indices: 1d integer tensor of shape (batch_size) satisfying
            0 <= i < dim1 for each element i.

    # Returns
        A tensor with shape (batch_size, dim2, ..., dimN)
        equal to `reference[np.arange(batch_size), indices]`.
    '''
    batch_size = shape(reference)[0]
    indices = tf.pack([tf.range(batch_size), indices], axis=1)
    return tf.gather_nd(reference, indices)
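A concrete picture of what `batch_gather` returns, using the documented NumPy equivalent (illustration only; the shapes are made up):

import numpy as np

reference = np.arange(12).reshape(3, 4)   # batch_size=3, dim1=4
indices = np.array([2, 0, 3])             # one index into dim1 per batch element

# Picks reference[b, indices[b]] for every b in the batch.
print(reference[np.arange(3), indices])   # [ 2  4 11]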


# HIGH ORDER FUNCTIONS

def map_fn(fn, elems, name=None):
27 changes: 27 additions & 0 deletions keras/backend/theano_backend.py
@@ -2059,6 +2059,33 @@ def ctc_step(y_true_step, y_pred_step, input_length_step, label_length_step):
    return ret


def logsumexp(x, axis=None):
    '''Returns `log(sum(exp(x), axis=axis))` with improved numerical stability.
    '''
    xmax = max(x, axis=axis, keepdims=True)
    xmax_ = max(x, axis=axis)
    return xmax_ + log(sum(exp(x - xmax), axis=axis))


def batch_gather(reference, indices):
    '''Batchwise gathering of row indices.

    The numpy equivalent is `reference[np.arange(batch_size), indices]`.

    # Arguments
        reference: tensor with ndim >= 2 of shape
            (batch_size, dim1, dim2, ..., dimN)
        indices: 1d integer tensor of shape (batch_size) satisfying
            0 <= i < dim1 for each element i.

    # Returns
        A tensor with shape (batch_size, dim2, ..., dimN)
        equal to `reference[np.arange(batch_size), indices]`.
    '''
    batch_size = shape(reference)[0]
    return reference[T.arange(batch_size), indices]


# HIGH ORDER FUNCTIONS

def map_fn(fn, elems, name=None):
163 changes: 163 additions & 0 deletions keras/datasets/conll2000.py
@@ -0,0 +1,163 @@
'''Text chunking dataset for testing sequence labeling architectures.

Source: http://www.cnts.ua.ac.be/conll2000/chunking
'''
from __future__ import absolute_import, unicode_literals
from six.moves import cPickle
import gzip
from ..utils.data_utils import get_file
from six.moves import zip
import numpy as np
import sys
import os
from collections import Counter
from itertools import chain


CHUNK_TAGS = [
'PAD',
'B-ADJP',
'B-ADVP',
'B-CONJP',
'B-INTJ',
'B-LST',
'B-NP',
'B-PP',
'B-PRT',
'B-SBAR',
'B-UCP',
'B-VP',
'I-ADJP',
'I-ADVP',
'I-CONJP',
'I-INTJ',
'I-LST',
'I-NP',
'I-PP',
'I-PRT',
'I-SBAR',
'I-UCP',
'I-VP',
'O'
]


POS_TAGS = [
'<PAD>',
'#',
'$',
"''",
'(',
')',
',',
'.',
':',
'CC',
'CD',
'DT',
'EX',
'FW',
'IN',
'JJ',
'JJR',
'JJS',
'MD',
'NN',
'NNP',
'NNPS',
'NNS',
'PDT',
'POS',
'PRP',
'PRP$',
'RB',
'RBR',
'RBS',
'RP',
'SYM',
'TO',
'UH',
'VB',
'VBD',
'VBG',
'VBN',
'VBP',
'VBZ',
'WDT',
'WP',
'WP$',
'WRB',
'``'
]


def load_data(word_preprocess=lambda x: x):
    '''Loads the conll2000 text chunking dataset.

    # Arguments:
        word_preprocess: A callable applied to each word form before indexing.
            For example, use `lambda w: w.lower()` when all words should be
            lowercased.
    '''
    X_words_train, X_pos_train, y_train = load_file('train.txt.gz', md5_hash='6969c2903a1f19a83569db643e43dcc8')
    X_words_test, X_pos_test, y_test = load_file('test.txt.gz', md5_hash='a916e1c2d83eb3004b38fc6fcd628939')

    index2word = _fit_term_index(X_words_train, reserved=['<PAD>', '<UNK>'], preprocess=word_preprocess)
    word2index = _invert_index(index2word)

    index2pos = POS_TAGS
    pos2index = _invert_index(index2pos)

    index2chunk = CHUNK_TAGS
    chunk2index = _invert_index(index2chunk)

    X_words_train = np.array([[word2index[word_preprocess(w)] for w in words] for words in X_words_train])
    X_pos_train = np.array([[pos2index[t] for t in pos_tags] for pos_tags in X_pos_train])
    y_train = np.array([[chunk2index[t] for t in chunk_tags] for chunk_tags in y_train])
    X_words_test = np.array([[word2index.get(word_preprocess(w), word2index['<UNK>']) for w in words] for words in X_words_test])
    X_pos_test = np.array([[pos2index[t] for t in pos_tags] for pos_tags in X_pos_test])
    y_test = np.array([[chunk2index[t] for t in chunk_tags] for chunk_tags in y_test])
    return (X_words_train, X_pos_train, y_train), (X_words_test, X_pos_test, y_test), (index2word, index2pos, index2chunk)


def _fit_term_index(terms, reserved=[], preprocess=lambda x: x):
    all_terms = chain(*terms)
    all_terms = map(preprocess, all_terms)
    term_freqs = Counter(all_terms).most_common()
    id2term = reserved + [term for term, tf in term_freqs]
    return id2term


def _invert_index(id2term):
    return {term: i for i, term in enumerate(id2term)}
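A quick illustration of the two helpers above, with made-up input: terms are ordered by frequency after the reserved entries, and the inverted index maps each term back to its position.

index2word = _fit_term_index([['The', 'cat'], ['the', 'dog']],
                             reserved=['<PAD>', '<UNK>'],
                             preprocess=lambda w: w.lower())
# ['<PAD>', '<UNK>', 'the', 'cat', 'dog']  (tie order of equally frequent terms may vary)
word2index = _invert_index(index2word)
# {'<PAD>': 0, '<UNK>': 1, 'the': 2, 'cat': 3, 'dog': 4}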


def load_file(filename, md5_hash):
    '''Loads and parses a conll2000 data file.

    # Arguments:
        filename: The requested filename.
        md5_hash: The expected md5 hash.
    '''
    path = get_file('conll2000_' + filename,
                    origin='http://www.cnts.ua.ac.be/conll2000/chunking/' + filename,
                    md5_hash=md5_hash)
    with gzip.open(path, 'rt') as fd:
        rows = _parse_grid_iter(fd)
        words, pos_tags, chunk_tags = zip(*[zip(*row) for row in rows])
    return words, pos_tags, chunk_tags


def _parse_grid_iter(fd, sep=' '):
    '''Yields the parsed sentences for a given file descriptor.
    '''
    sentence = []
    for line in fd:
        if line.strip() == '':
            # Blank lines separate sentences.
            if len(sentence) > 0:
                yield sentence
                sentence = []
        else:
            sentence.append(line.strip().split(sep))
    if len(sentence) > 0:
        yield sentence
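The raw CoNLL-2000 files store one token per line as `word POS chunk-tag`, with blank lines separating sentences; the parser above turns each sentence into a list of such triples. A small hand-made illustration (the data lines are typed in for the example, not read from the real files):

import io

sample = io.StringIO(
    'Confidence NN B-NP\n'
    'in IN B-PP\n'
    'the DT B-NP\n'
    'pound NN I-NP\n'
    '\n'
    'It PRP B-NP\n'
    'is VBZ B-VP\n'
)
for sentence in _parse_grid_iter(sample):
    print(sentence)
# [['Confidence', 'NN', 'B-NP'], ['in', 'IN', 'B-PP'], ['the', 'DT', 'B-NP'], ['pound', 'NN', 'I-NP']]
# [['It', 'PRP', 'B-NP'], ['is', 'VBZ', 'B-VP']]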
1 change: 1 addition & 0 deletions keras/layers/__init__.py
@@ -1,5 +1,6 @@
from __future__ import absolute_import
from ..engine import Layer, Input, InputLayer, Merge, merge, InputSpec
from .crf import *
from .core import *
from .convolutional import *
from .pooling import *