[MXNet-1375][Fit API] Added RNN integration test for fit() API #14547
@@ -106,6 +106,22 @@ core_logic: {
          utils.docker_run('ubuntu_nightly_gpu', 'nightly_tutorial_test_ubuntu_python3_gpu', true, '1500m')
        }
      }
    },
Review comment: This should not be in the JenkinsfileForBinaries.
Reply: Done in ae81b08. Moved to …
    'estimator: RNN GPU': {
      node(NODE_LINUX_GPU) {
        ws('workspace/estimator-test-rnn-gpu') {
          utils.unpack_and_init('gpu', mx_lib)
          utils.docker_run('ubuntu_nightly_gpu', 'nightly_estimator_test_rnn_gpu', true)
        }
      }
    },
    'estimator: RNN CPU': {
      node(NODE_LINUX_CPU) {
        ws('workspace/estimator-test-rnn-cpu') {
          utils.unpack_and_init('cpu', mx_lib)
          utils.docker_run('ubuntu_nightly_cpu', 'nightly_estimator_test_rnn_cpu', true)
        }
      }
    }
  }
}
@@ -0,0 +1,272 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Gluon Text Sentiment Classification Example using RNN/CNN

Example modified from the links below:
https://github.com/d2l-ai/d2l-en/blob/master/chapter_natural-language-processing/sentiment-analysis-rnn.md
https://github.com/d2l-ai/d2l-en/blob/master/chapter_natural-language-processing/sentiment-analysis-cnn.md
"""
import argparse
import os
import tarfile
import random
import collections

import mxnet as mx
from mxnet import nd
from mxnet.contrib import text
from mxnet.gluon import data as gdata, loss as gloss, utils as gutils, nn, rnn
Review comment: Just import data, loss, and utils; don't rename them. The book renamed them to separate mxnet from the gluoncv/nlp and d2l packages; we only have mxnet here.
Reply: Thanks, updated in ae81b08.
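For illustration, the suggested import style would look something like this (a sketch; call sites would need updating to use the plain module names):

from mxnet.gluon import data, loss, utils, nn, rnn

# Call sites then use the plain module names, e.g.:
#   dataloader = data.DataLoader(dataset, batch_size=64)
#   softmax_ce = loss.SoftmaxCrossEntropyLoss()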
from mxnet.gluon.estimator import estimator as est
class TextCNN(nn.Block):
    def __init__(self, vocab, embed_size, kernel_sizes, num_channels,
                 **kwargs):
        super(TextCNN, self).__init__(**kwargs)
        self.embedding = nn.Embedding(len(vocab), embed_size)
        # This embedding layer does not participate in training
        self.constant_embedding = nn.Embedding(len(vocab), embed_size)
        self.dropout = nn.Dropout(0.5)
        self.decoder = nn.Dense(2)
        # The max-over-time pooling layer has no weights, so a single
        # instance can be shared
        self.pool = nn.GlobalMaxPool1D()
        # Create multiple one-dimensional convolutional layers
        self.convs = nn.Sequential()
        for c, k in zip(num_channels, kernel_sizes):
            self.convs.add(nn.Conv1D(c, k, activation='relu'))
    def forward(self, inputs):
        # Concatenate the outputs of the two embedding layers, each with
        # shape (batch size, number of words, word vector dimension), along
        # the word vector dimension
        embeddings = nd.concat(
            self.embedding(inputs), self.constant_embedding(inputs), dim=2)
        # Conv1D expects the channel dimension second, so move the word
        # vector dimension (the channel dimension of the one-dimensional
        # convolution) ahead of the word dimension
        embeddings = embeddings.transpose((0, 2, 1))
        # For each one-dimensional convolutional layer, max-over-time
        # pooling yields an NDArray with shape (batch size, channel size, 1).
        # Use the flatten function to remove the last dimension and then
        # concatenate on the channel dimension
        encoding = nd.concat(*[nd.flatten(
            self.pool(conv(embeddings))) for conv in self.convs], dim=1)
        # After applying dropout, use a fully connected layer to obtain the
        # output
        outputs = self.decoder(self.dropout(encoding))
        return outputs
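As a quick illustration of the shape comments above, a hypothetical check (not part of the PR; the vocabulary and sizes are arbitrary):

net = TextCNN(vocab=range(100), embed_size=100,
              kernel_sizes=[3, 4, 5], num_channels=[100, 100, 100])
net.initialize()
tokens = mx.nd.random.randint(low=0, high=100, shape=(4, 500)).astype('float32')
out = net(tokens)
assert out.shape == (4, 2)  # (batch size, number of output classes)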
class BiRNN(nn.Block):
    def __init__(self, vocab, embed_size, num_hiddens, num_layers, **kwargs):
        super(BiRNN, self).__init__(**kwargs)
        self.embedding = nn.Embedding(len(vocab), embed_size)
        # Set bidirectional to True to get a bidirectional recurrent neural
        # network
        self.encoder = rnn.LSTM(num_hiddens, num_layers=num_layers,
                                bidirectional=True, input_size=embed_size)
        self.decoder = nn.Dense(2)

    def forward(self, inputs):
        # The shape of inputs is (batch size, number of words). Because the
        # LSTM needs the sequence as its first dimension, the input is
        # transposed before the word features are extracted. The output
        # shape is (number of words, batch size, word vector dimension).
        embeddings = self.embedding(inputs.T)
        # The shape of states is (number of words, batch size, 2 * number of
        # hidden units).
        states = self.encoder(embeddings)
        # Concatenate the hidden states of the initial time step and final
        # time step to use as the input of the fully connected layer. Its
        # shape is (batch size, 4 * number of hidden units)
        encoding = nd.concat(states[0], states[-1])
        outputs = self.decoder(encoding)
        return outputs
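Similarly, the shape bookkeeping in BiRNN.forward can be traced with a hypothetical check (sizes are illustrative, not part of the PR):

net = BiRNN(vocab=range(100), embed_size=100, num_hiddens=100, num_layers=2)
net.initialize()
tokens = mx.nd.random.randint(low=0, high=100, shape=(4, 500)).astype('float32')
# embeddings: (500, 4, 100); states: (500, 4, 200); encoding: (4, 400)
out = net(tokens)
assert out.shape == (4, 2)  # (batch size, number of classes)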
def download_imdb(data_dir='./data'):
Review comment: Can you keep this path as /tmp/data?
Reply: Done in commit ae81b08.
Review comment: There are 5 functions for loading and processing the dataset; can't they be combined into one?
Reply: Keeping them separate makes the code more readable and the flow easier to navigate.
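For reference, a rough sketch of what the reviewer's combined version might look like (a hypothetical load_imdb_dataset helper chaining the functions defined in this file; not what the PR adopts):

def load_imdb_dataset(batch_size, data_dir='/tmp/data'):
    # Hypothetical one-call loader: download, read, build vocab, preprocess, batch
    download_imdb(data_dir)
    train_data, test_data = read_imdb('train'), read_imdb('test')
    vocab = get_vocab_imdb(train_data)
    train_set = gdata.ArrayDataset(*preprocess_imdb(train_data, vocab))
    test_set = gdata.ArrayDataset(*preprocess_imdb(test_data, vocab))
    return (gdata.DataLoader(train_set, batch_size, shuffle=True),
            gdata.DataLoader(test_set, batch_size),
            vocab)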
    '''
    Download and extract the IMDB dataset
    '''
    url = ('http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz')
Review comment: Should we keep this data in an S3 bucket instead, for more reliable downloads?
Review comment: +1 to the comment by @piyushghai. Or maybe check if it's already in gluon-nlp?
Reply: We would have to handle the licensing part if we added it to our S3 bucket. Let's keep it this way for now. I also found this R example …
    sha1 = '01ada507287d82875905620988597833ad4e0903'
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    file_path = os.path.join(data_dir, 'aclImdb_v1.tar.gz')
    if not os.path.isfile(file_path):
        file_path = gutils.download(url, data_dir, sha1_hash=sha1)
    with tarfile.open(file_path, 'r') as f:
        f.extractall(data_dir)
def read_imdb(folder='train'):
    '''
    Read the IMDB dataset
    '''
    data = []
    for label in ['pos', 'neg']:
        folder_name = os.path.join('./data/aclImdb/', folder, label)
        for file in os.listdir(folder_name):
            with open(os.path.join(folder_name, file), 'rb') as f:
                review = f.read().decode('utf-8').replace('\n', '').lower()
                data.append([review, 1 if label == 'pos' else 0])
    random.shuffle(data)
    return data
def get_tokenized_imdb(data):
    '''
    Tokenize the words
    '''

    def tokenizer(text):
        return [tok.lower() for tok in text.split(' ')]

    return [tokenizer(review) for review, _ in data]
def get_vocab_imdb(data):
    '''
    Build the vocabulary of indexed tokens
    '''
    tokenized_data = get_tokenized_imdb(data)
    counter = collections.Counter([tk for st in tokenized_data for tk in st])
    return text.vocab.Vocabulary(counter, min_freq=5)
def preprocess_imdb(data, vocab):
    '''
    Make the length of each comment 500 by truncating or padding with 0s
    '''
    max_l = 500

    def pad(x):
        return x[:max_l] if len(x) > max_l else x + [0] * (max_l - len(x))

    tokenized_data = get_tokenized_imdb(data)
    features = nd.array([pad(vocab.to_indices(x)) for x in tokenized_data])
    labels = nd.array([score for _, score in data])
    return features, labels
def test_estimator_cpu():
    '''
    Test estimator by doing one pass over each model with synthetic data
    '''
    models = ['TextCNN', 'BiRNN']
    context = mx.cpu()
    batch_size = 64
    num_epochs = 1
    lr = 0.01
    embed_size = 100

    train_data = mx.nd.random.randint(low=0, high=100, shape=(2 * batch_size, 500))
    train_label = mx.nd.random.randint(low=0, high=2, shape=(2 * batch_size,))
    val_data = mx.nd.random.randint(low=0, high=100, shape=(batch_size, 500))
    val_label = mx.nd.random.randint(low=0, high=2, shape=(batch_size,))

    train_dataloader = gdata.DataLoader(dataset=gdata.ArrayDataset(train_data, train_label),
                                        batch_size=batch_size, shuffle=True)
    val_dataloader = gdata.DataLoader(dataset=gdata.ArrayDataset(val_data, val_label),
                                      batch_size=batch_size)
    vocab_list = mx.nd.zeros(shape=(100,))

    # Get the model
    for model in models:
        if model == 'TextCNN':
Review comment: To avoid splitting the logic here, we could have a get_model() function which returns the net and related items, and use it for both tests.
Reply: We have a separate model class for each model, and we instantiate the object based on the selection. The …
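A minimal sketch of the reviewer's get_model() idea (a hypothetical helper; the name, signature, and hyperparameters are illustrative only):

def get_model(model, vocab, embed_size):
    # Hypothetical factory: pick the network by name so both tests share one code path
    if model == 'TextCNN':
        kernel_sizes, nums_channels = [3, 4, 5], [100, 100, 100]
        return TextCNN(vocab, embed_size, kernel_sizes, nums_channels)
    num_hiddens, num_layers = 100, 2
    return BiRNN(vocab, embed_size, num_hiddens, num_layers)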
            kernel_sizes, nums_channels = [3, 4, 5], [100, 100, 100]
            net = TextCNN(vocab_list, embed_size, kernel_sizes, nums_channels)
        else:
            num_hiddens, num_layers = 100, 2
            net = BiRNN(vocab_list, embed_size, num_hiddens, num_layers)
        net.initialize(mx.init.Xavier(), ctx=context)
        # Define trainer
        trainer = mx.gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr})
        # Define loss and evaluation metrics
        loss = gloss.SoftmaxCrossEntropyLoss()
        acc = mx.metric.Accuracy()

        # Define estimator
        e = est.Estimator(net=net, loss=loss, metrics=acc,
                          trainers=trainer, context=context)
        # Begin training
        e.fit(train_data=train_dataloader, val_data=val_dataloader,
              epochs=num_epochs, batch_size=batch_size)
def test_estimator_gpu():
    '''
    Test estimator by training a bidirectional RNN for 5 epochs on the IMDB
    dataset and verifying accuracy
    '''

Review comment: There are a lot of things in common between the two tests. I see in …
Reply: Modified the code to reduce the redundancy. Could you please review it again? Thanks!
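One hedged sketch of how the common pieces could be factored out (a hypothetical run_estimator helper, not the refactor that actually landed):

def run_estimator(net, train_dataloader, val_dataloader, context, lr, num_epochs):
    # Hypothetical shared driver for both the CPU and GPU tests
    net.initialize(mx.init.Xavier(), ctx=context)
    trainer = mx.gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr})
    loss = gloss.SoftmaxCrossEntropyLoss()
    acc = mx.metric.Accuracy()
    e = est.Estimator(net=net, loss=loss, metrics=acc,
                      trainers=trainer, context=context)
    e.fit(train_data=train_dataloader, val_data=val_dataloader,
          epochs=num_epochs)
    return e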
    batch_size = 64
    num_epochs = 5
    lr = 0.01
    embed_size = 100

    # Set context
    if mx.context.num_gpus() > 0:
Review comment: Can write this more succinctly as …
Reply: Yes, absolutely. Modified it in ae81b08.
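The suggestion itself is truncated in the thread; presumably the succinct form is something like:

# Assumed form of the reviewer's suggestion
ctx = mx.gpu(0) if mx.context.num_gpus() > 0 else mx.cpu()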
        ctx = mx.gpu(0)
    else:
        ctx = mx.cpu()
    # data
    download_imdb()
    train_data, test_data = read_imdb('train'), read_imdb('test')
    vocab = get_vocab_imdb(train_data)

    train_set = gdata.ArrayDataset(*preprocess_imdb(train_data, vocab))
    test_set = gdata.ArrayDataset(*preprocess_imdb(test_data, vocab))
    train_dataloader = gdata.DataLoader(train_set, batch_size, shuffle=True)
    test_dataloader = gdata.DataLoader(test_set, batch_size)
    # Model
    num_hiddens, num_layers = 100, 2
    net = BiRNN(vocab, embed_size, num_hiddens, num_layers)
    net.initialize(mx.init.Xavier(), ctx=ctx)

    glove_embedding = text.embedding.create(
        'glove', pretrained_file_name='glove.6B.100d.txt', vocabulary=vocab)

    net.embedding.weight.set_data(glove_embedding.idx_to_vec)
    net.embedding.collect_params().setattr('grad_req', 'null')

    # Define trainer
    trainer = mx.gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr})
    # Define loss and evaluation metrics
    loss = gloss.SoftmaxCrossEntropyLoss()
    acc = mx.metric.Accuracy()

    # Define estimator
    e = est.Estimator(net=net, loss=loss, metrics=acc,
                      trainers=trainer, context=ctx)
    # Begin training
    e.fit(train_data=train_dataloader, val_data=test_dataloader,
          epochs=num_epochs)

    assert e.train_stats['train_' + acc.name][num_epochs - 1] > 0.70
if __name__ == '__main__':
Review comment: You don't need a main in integration tests.
Reply: Updated in ae81b08.
    parser = argparse.ArgumentParser(description='test gluon estimator')
    parser.add_argument('--type', type=str, default='cpu')
    opt = parser.parse_args()
    if opt.type == 'cpu':
        test_estimator_cpu()
    elif opt.type == 'gpu':
        test_estimator_gpu()
    else:
        raise RuntimeError("Unknown test type")
Review comment: Use nosetest and assert accuracy for GPU; do not run Python scripts here.
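A hedged sketch of what the suggested setup might look like (the runner invocation and file path are assumptions, not confirmed by this PR):

# With the __main__ block removed, nose discovers the test_* functions directly.
# Hypothetical CI invocation (path assumed):
#   nosetests --verbose tests/nightly/estimator/test_sentiment_rnn.py
# The GPU test already ends with an accuracy assertion:
#   assert e.train_stats['train_' + acc.name][num_epochs - 1] > 0.70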