diff --git a/CMakeLists.txt b/CMakeLists.txt index 6e6b17880e27..bbbb51dfb248 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -186,7 +186,7 @@ endif() list(APPEND mxnet_LINKER_LIBS ${mshadow_LINKER_LIBS}) foreach(var ${C_CXX_INCLUDE_DIRECTORIES}) - include_directories(${var}) + include_directories(${var}) endforeach() include_directories("include") @@ -201,9 +201,13 @@ include_directories("dlpack/include") # add_subdirectory(dlpack) #endif() +# Prevent stripping out symbols (operator registrations, for example) if(NOT MSVC AND NOT APPLE) set(BEGIN_WHOLE_ARCHIVE -Wl,--whole-archive) set(END_WHOLE_ARCHIVE -Wl,--no-whole-archive) +elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang") + # using regular Clang or AppleClang + set(BEGIN_WHOLE_ARCHIVE -Wl,-force_load) endif() if(UNIX) @@ -332,7 +336,7 @@ if(USE_CUDNN AND USE_CUDA) add_definitions(-DUSE_CUDNN) include_directories(SYSTEM ${CUDNN_INCLUDE}) list(APPEND mxnet_LINKER_LIBS ${CUDNN_LIBRARY}) - add_definitions(-DMSHADOW_USE_CUDNN=1) + add_definitions(-DMSHADOW_USE_CUDNN=1) endif() endif() @@ -372,17 +376,17 @@ assign_source_group("Include" ${GROUP_Include}) assign_source_group("CUDA" ${GROUP_CUDA}) if(USE_PLUGINS_WARPCTC) - set(WARPCTC_INCLUDE "" CACHE PATH "WARPCTC include") + set(WARPCTC_INCLUDE "" CACHE PATH "WARPCTC include") set(WARPCTC_LIB_DEBUG "" CACHE FILEPATH "WARPCTC lib") set(WARPCTC_LIB_RELEASE "" CACHE FILEPATH "WARPCTC lib") - include_directories(SYSTEM ${WARPCTC_INCLUDE}) - list(APPEND mxnet_LINKER_LIBS ${WARPCTC_LIB}) - FILE(GLOB_RECURSE PLUGINS_SOURCE "plugin/warpctc/*.cc" "plugin/warpctc/*.h") - FILE(GLOB_RECURSE PLUGINS_CUSRC "plugin/warpctc/*.cu") - list(APPEND SOURCE ${PLUGINS_SOURCE}) - list(APPEND CUDA ${PLUGINS_CUSRC}) + include_directories(SYSTEM ${WARPCTC_INCLUDE}) + list(APPEND mxnet_LINKER_LIBS ${WARPCTC_LIB}) + FILE(GLOB_RECURSE PLUGINS_SOURCE "plugin/warpctc/*.cc" "plugin/warpctc/*.h") + FILE(GLOB_RECURSE PLUGINS_CUSRC "plugin/warpctc/*.cu") + list(APPEND SOURCE ${PLUGINS_SOURCE}) + list(APPEND CUDA ${PLUGINS_CUSRC}) endif() if(USE_OPERATOR_TUNING) @@ -425,11 +429,11 @@ if(USE_PLUGIN_CAFFE) endif() if (NOT (EXTRA_OPERATORS STREQUAL "")) - mxnet_source_group("Extra" GLOB_RECURSE "${EXTRA_OPERATORS}/*.cc") - mxnet_source_group("Extra\\Cuda" GLOB_RECURSE "${EXTRA_OPERATORS}/*.cu") - FILE(GLOB_RECURSE EXTRA_SRC "${EXTRA_OPERATORS}/*.cc") - FILE(GLOB_RECURSE EXTRA_CUSRC "${EXTRA_OPERATORS}/*.cu") - list(APPEND SOURCE ${EXTRA_SRC} ${EXTRA_CUSRC}) + mxnet_source_group("Extra" GLOB_RECURSE "${EXTRA_OPERATORS}/*.cc") + mxnet_source_group("Extra\\Cuda" GLOB_RECURSE "${EXTRA_OPERATORS}/*.cu") + FILE(GLOB_RECURSE EXTRA_SRC "${EXTRA_OPERATORS}/*.cc") + FILE(GLOB_RECURSE EXTRA_CUSRC "${EXTRA_OPERATORS}/*.cu") + list(APPEND SOURCE ${EXTRA_SRC} ${EXTRA_CUSRC}) endif() if(MSVC) @@ -567,7 +571,7 @@ if(MSVC AND USE_MXNET_LIB_NAMING) endif() if(USE_PROFILER) - add_definitions(-DMXNET_USE_PROFILER) + add_definitions(-DMXNET_USE_PROFILER) endif() add_subdirectory(tests) @@ -585,7 +589,7 @@ if (INSTALL_EXAMPLES) endif() if (USE_SIGNAL_HANDLER) - add_definitions(-DMXNET_USE_SIGNAL_HANDLER=1) + add_definitions(-DMXNET_USE_SIGNAL_HANDLER=1) endif() # AUTO_INSTALL_DIR -> Optional: specify post-build install direcory diff --git a/Jenkinsfile b/Jenkinsfile index 5b5a2f3d8b43..4fc12f3dab6f 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -169,6 +169,42 @@ try { } } }, + 'CPU: Clang 3.9': { + node('mxnetlinux-cpu') { + ws('workspace/build-cpu-clang') { + init_git() + def flag = """ \ + USE_PROFILER=1 \ + USE_CPP_PACKAGE=1 \ + USE_BLAS=openblas \ + USE_OPENMP=0 \ + CXX=clang++-3.9 \ + CC=clang-3.9 \ + -j\$(nproc) + """ + make("cpu_clang", flag) + pack_lib('cpu_clang') + } + } + }, + 'CPU: Clang 5': { + node('mxnetlinux-cpu') { + ws('workspace/build-cpu-clang') { + init_git() + def flag = """ \ + USE_PROFILER=1 \ + USE_CPP_PACKAGE=1 \ + USE_BLAS=openblas \ + USE_OPENMP=1 \ + CXX=clang++-5.0 \ + CC=clang-5.0 \ + -j\$(nproc) + """ + make("cpu_clang", flag) + pack_lib('cpu_clang') + } + } + }, 'CPU: MKLDNN': { node('mxnetlinux-cpu') { ws('workspace/build-mkldnn-cpu') { diff --git a/R-package/tests/testthat/test_model.R b/R-package/tests/testthat/test_model.R index 7707f7157688..8cdd396c2525 100644 --- a/R-package/tests/testthat/test_model.R +++ b/R-package/tests/testthat/test_model.R @@ -172,7 +172,6 @@ test_that("Fine-tune", { }) test_that("Matrix Factorization", { - skip("Disabled due to an unavailible http server. Tracked here: https://git.io/vNkrE") GetMovieLens() DF <- read.table("./data/ml-100k/u.data", header = F, sep = "\t") names(DF) <- c("user", "item", "score", "time") diff --git a/cpp-package/include/mxnet-cpp/optimizer.h b/cpp-package/include/mxnet-cpp/optimizer.h index f3763bbd6e67..320b13eebf2d 100644 --- a/cpp-package/include/mxnet-cpp/optimizer.h +++ b/cpp-package/include/mxnet-cpp/optimizer.h @@ -146,6 +146,20 @@ class SGDOptimizer : public Optimizer { AtomicSymbolCreator mom_update_handle_; }; +class SignumOptimizer : public Optimizer { + public: + explicit SignumOptimizer(unsigned begin_num_update = 0); + std::string GetType() const override; + void Update(int index, NDArray weight, NDArray grad) override; + private: + virtual ~SignumOptimizer(); + void CreateState_(int index, NDArray weight) override; + std::map states_; + AtomicSymbolCreator update_handle_; + AtomicSymbolCreator mom_update_handle_; +}; + + class RMSPropOptimizer : public Optimizer { public: explicit RMSPropOptimizer(unsigned begin_num_update = 0); diff --git a/cpp-package/include/mxnet-cpp/optimizer.hpp b/cpp-package/include/mxnet-cpp/optimizer.hpp index cb8442dc9ceb..e3d47d1161c6 100644 --- a/cpp-package/include/mxnet-cpp/optimizer.hpp +++ b/cpp-package/include/mxnet-cpp/optimizer.hpp @@ -131,6 +131,7 @@ inline Optimizer* OptimizerRegistry::Find(const std::string& name) { MXNETCPP_REGISTER_OPTIMIZER(adam, AdamOptimizer); MXNETCPP_REGISTER_OPTIMIZER(adagrad, AdaGradOptimizer); MXNETCPP_REGISTER_OPTIMIZER(adadelta, AdaDeltaOptimizer); + MXNETCPP_REGISTER_OPTIMIZER(signum, SignumOptimizer); auto it = cmap().find(name); if (it == cmap().end()) return nullptr; @@ -200,6 +201,69 @@ inline void SGDOptimizer::CreateState_(int index, NDArray weight) { } } +// inplementing Signum optimizer + +inline SignumOptimizer::SignumOptimizer(unsigned begin_num_update) + : Optimizer(begin_num_update) { + update_handle_ = op_map()->GetSymbolCreator("signsgd_update"); + mom_update_handle_ = op_map()->GetSymbolCreator("signum_update"); +} + +inline std::string SignumOptimizer::GetType() const { + return "signum"; +} + +inline SignumOptimizer::~SignumOptimizer() { + for (auto &it : states_) { + delete it.second; + } +} + +inline void SignumOptimizer::Update(int index, NDArray weight, NDArray grad) { + if (states_.count(index) == 0) { + CreateState_(index, weight); + } + + params_["lr"] = std::to_string(GetLR_(index)); + params_["wd"] = std::to_string(GetWD_(index)); + UpdateCount_(index); + auto keys = GetParamKeys_(); + auto values = GetParamValues_(); + CHECK_EQ(keys.size(), values.size()); + + NDArrayHandle inputs[3]; + inputs[0] = weight.GetHandle(); + inputs[1] = grad.GetHandle(); + + int num_outputs = 1; + NDArrayHandle output = weight.GetHandle(); + NDArrayHandle *outputs = &output; + + if (states_[index] == nullptr) { + MXImperativeInvoke(update_handle_, 2, inputs, + &num_outputs, &outputs, + keys.size(), keys.data(), values.data()); + } else { + inputs[2] = states_[index]->GetHandle(); + MXImperativeInvoke(mom_update_handle_, 3, inputs, + &num_outputs, &outputs, + keys.size(), keys.data(), values.data()); + } +} + +inline void SignumOptimizer::CreateState_(int index, NDArray weight) { + if (params_.count("momentum") == 0) { + states_[index] = nullptr; + } else { + states_[index] = new NDArray(weight.GetShape(), weight.GetContext()); + *states_[index] = 0; + } +} + +// finish implementing Signum + + + inline RMSPropOptimizer::RMSPropOptimizer(unsigned begin_num_update) : Optimizer(begin_num_update) { update_handle_ = op_map()->GetSymbolCreator("rmsprop_update"); diff --git a/docs/tutorials/speech_recognition/baidu_warp_ctc.md b/docs/tutorials/speech_recognition/baidu_warp_ctc.md deleted file mode 100644 index 6277a19bfde4..000000000000 --- a/docs/tutorials/speech_recognition/baidu_warp_ctc.md +++ /dev/null @@ -1,97 +0,0 @@ -# Using Baidu Warp-CTC with MXNet - - -Baidu-WarpCTC is a CTC implementation by Baidu that supports using GPU processors. It supports using CTC with LSTM to solve label alignment problems in many areas, such as OCR and speech recognition. - -You can get the source code for the example on [GitHub](https://github.com/dmlc/mxnet/tree/master/example/warpctc). - -## Install Baidu Warp-CTC - -``` - cd ~/ - git clone https://github.com/baidu-research/warp-ctc - cd warp-ctc - mkdir build - cd build - cmake .. - make - sudo make install -``` - -## Enable Warp-CTC in MXNet - -``` - comment out following lines in make/config.mk - WARPCTC_PATH = $(HOME)/warp-ctc - MXNET_PLUGINS += plugin/warpctc/warpctc.mk - - rebuild mxnet by - make clean && make -j4 -``` - -## Run Examples - -There are two examples. One is a toy example that validates CTC integration. The second is an OCR example with LSTM and CTC. You can run it by typing the following code: - -``` - cd examples/warpctc - python lstm_ocr.py -``` - -The OCR example is constructed as follows: - -1. It generates a 80x30-pixel image for a 4-digit captcha using a Python captcha library. -2. The 80x30 image is used as 80 input for LSTM, and every input is one column of the image (a 30 dim vector). -3. The output layer use CTC loss. - -The following code shows the detailed construction of the net: - -``` - def lstm_unroll(num_lstm_layer, seq_len, - num_hidden, num_label): - param_cells = [] - last_states = [] - for i in range(num_lstm_layer): - param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i), - i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i), - h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i), - h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i))) - state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i), - h=mx.sym.Variable("l%d_init_h" % i)) - last_states.append(state) - assert(len(last_states) == num_lstm_layer) - data = mx.sym.Variable('data') - label = mx.sym.Variable('label') - - #every column of image is an input, there are seq_len inputs - wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1) - hidden_all = [] - for seqidx in range(seq_len): - hidden = wordvec[seqidx] - for i in range(num_lstm_layer): - next_state = lstm(num_hidden, indata=hidden, - prev_state=last_states[i], - param=param_cells[i], - seqidx=seqidx, layeridx=i) - hidden = next_state.h - last_states[i] = next_state - hidden_all.append(hidden) - hidden_concat = mx.sym.Concat(*hidden_all, dim=0) - pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11) - - # here we do NOT need to transpose label as other lstm examples do - label = mx.sym.Reshape(data=label, target_shape=(0,)) - #label should be int type, so use cast - label = mx.sym.Cast(data = label, dtype = 'int32') - sm = mx.sym.WarpCTC(data=pred, label=label, label_length = num_label, input_length = seq_len) - return sm -``` - -## Supporting Multi-label Length - -Provide labels with length b. For samples whose label length is smaller than b, append 0 to the label data to make it have length b. - -0 is reserved for a blank label. - -## Next Steps -* [MXNet tutorials index](http://mxnet.io/tutorials/index.html) diff --git a/docs/tutorials/speech_recognition/ctc.md b/docs/tutorials/speech_recognition/ctc.md new file mode 100644 index 000000000000..9c9a9c98dbd9 --- /dev/null +++ b/docs/tutorials/speech_recognition/ctc.md @@ -0,0 +1,15 @@ +# Connectionist Temporal Classification + +[Connectionist Temporal Classification](https://www.cs.toronto.edu/~graves/icml_2006.pdf) (CTC) is a cost function that is used to train Recurrent Neural Networks (RNNs) to label unsegmented input sequence data in supervised learning. For example, in a speech recognition application, using a typical cross-entropy loss, the input signal needs to be segmented into words or sub-words. However, using CTC-loss, it suffices to provide one label sequence for input sequence and the network learns both the alignment as well labeling. Baidu's warp-ctc page contains a more detailed [introduction to CTC-loss](https://github.com/baidu-research/warp-ctc#introduction). + +## CTC-loss in MXNet +MXNet supports two CTC-loss layers in Symbol API: + +* `mxnet.symbol.contrib.ctc_loss` is implemented in MXNet and included as part of the standard package. +* `mxnet.symbol.WarpCTC` uses Baidu's warp-ctc library and requires building warp-ctc library and mxnet library both from source. + +## LSTM OCR Example +MXNet's example folder contains a [CTC example](https://github.com/apache/incubator-mxnet/tree/master/example/ctc) for using CTC loss with an LSTM network to perform Optical Character Recognition (OCR) prediction on CAPTCHA images. The example demonstrates use of both CTC loss options, as well as inference after training using network symbol and parameter checkpoints. + +## Next Steps +* [MXNet tutorials index](http://mxnet.io/tutorials/index.html) diff --git a/docs/tutorials/speech_recognition/speech_lstm.md b/docs/tutorials/speech_recognition/speech_lstm.md deleted file mode 100644 index 17e2ca0002d6..000000000000 --- a/docs/tutorials/speech_recognition/speech_lstm.md +++ /dev/null @@ -1,156 +0,0 @@ -# Speech LSTM -You can get the source code for these examples on [GitHub](https://github.com/dmlc/mxnet/tree/master/example/speech-demo). - -## Speech Acoustic Modeling Example - -The examples folder contains examples for speech recognition: - -- [lstm_proj.py](https://github.com/dmlc/mxnet/tree/master/example/speech-demo/lstm_proj.py): Functions for building an LSTM network with and without a projection layer. -- [io_util.py](https://github.com/dmlc/mxnet/tree/master/example/speech-demo/io_util.py): Wrapper functions for `DataIter` over speech data. -- [train_lstm_proj.py](https://github.com/dmlc/mxnet/tree/master/example/speech-demo/train_lstm_proj.py): A script for training an LSTM acoustic model. -- [decode_mxnet.py](https://github.com/dmlc/mxnet/tree/master/example/speech-demo/decode_mxnet.py): A script for decoding an LSTMP acoustic model. -- [default.cfg](https://github.com/dmlc/mxnet/tree/master/example/speech-demo/default.cfg): Configuration for training on the `AMI` SDM1 dataset. You can use it as a template for writing other configuration files. -- [python_wrap](https://github.com/dmlc/mxnet/tree/master/example/speech-demo/python_wrap): C wrappers for Kaldi C++ code, built into an .so file. Python code that loads the .so file and calls the C wrapper functions in `io_func/feat_readers/reader_kaldi.py`. - -Connect to Kaldi: - -- [decode_mxnet.sh](https://github.com/dmlc/mxnet/tree/master/example/speech-demo/decode_mxnet.sh): Called by Kaldi to decode an acoustic model trained by MXNet (select the `simple` method for decoding). - -A full receipt: - -- [run_ami.sh](https://github.com/dmlc/mxnet/tree/master/example/speech-demo/run_ami.sh): A full receipt to train and decode an acoustic model on AMI. It takes features and alignment from Kaldi to train an acoustic model and decode it. - -To create the speech acoustic modeling example, use the following steps. - -### Build Kaldi - -Build Kaldi as shared libraries if you have not already done so. - -```bash -cd kaldi/src -./configure --shared # and other options that you need -make depend -make -``` - -### Build the Python Wrapper - -1. Copy or link the attached `python_wrap` folder to `kaldi/src`. -2. Compile python_wrap/. - -``` -cd kaldi/src/python_wrap/ -make -``` - -### Extract Features and Prepare Frame-level Labels - -The acoustic models use Mel filter-bank or MFCC as input features. They also need to use Kaldi to perform force-alignment to generate frame-level labels from the text transcriptions. For example, if you want to work on the `AMI` data `SDM1`, you can run `kaldi/egs/ami/s5/run_sdm.sh`. Before you can run the examples, you need to configure some paths in `kaldi/egs/ami/s5/cmd.sh` and `kaldi/egs/ami/s5/run_sdm.sh`. Refer to Kaldi's documentation for details. - -The default `run_sdm.sh` script generates the force-alignment labels in their stage 7, and saves the force-aligned labels in `exp/sdm1/tri3a_ali`. The default script generates MFCC features (13-dimensional). You can try training with the MFCC features, or you can create Mel filter-bank features by yourself. For example, you can use a script like this to compute Mel filter-bank features using Kaldi: - -```bash -#!/bin/bash -u - -. ./cmd.sh -. ./path.sh - -# SDM - Single Distant Microphone -micid=1 #which mic from array should be used? -mic=sdm$micid - -# Set bash to 'debug' mode, it prints the commands (option '-x') and exits on : -# -e 'error', -u 'undefined variable', -o pipefail 'error in pipeline', -set -euxo pipefail - -# Path where AMI gets downloaded (or where locally available): -AMI_DIR=$PWD/wav_db # Default, -data_dir=$PWD/data/$mic - -# make filter bank data -for dset in train dev eval; do - steps/make_fbank.sh --nj 48 --cmd "$train_cmd" $data_dir/$dset \ - $data_dir/$dset/log $data_dir/$dset/data-fbank - steps/compute_cmvn_stats.sh $data_dir/$dset \ - $data_dir/$dset/log $data_dir/$dset/data - - apply-cmvn --utt2spk=ark:$data_dir/$dset/utt2spk \ - scp:$data_dir/$dset/cmvn.scp scp:$data_dir/$dset/feats.scp \ - ark,scp:$data_dir/$dset/feats-cmvn.ark,$data_dir/$dset/feats-cmvn.scp - - mv $data_dir/$dset/feats-cmvn.scp $data_dir/$dset/feats.scp -done -``` -`apply-cmvn` provides mean-variance normalization. The default setup was applied per speaker. It's more common to perform mean-variance normalization for the whole corpus, and then feed the results to the neural networks: - -``` - compute-cmvn-stats scp:data/sdm1/train_fbank/feats.scp data/sdm1/train_fbank/cmvn_g.ark - apply-cmvn --norm-vars=true data/sdm1/train_fbank/cmvn_g.ark scp:data/sdm1/train_fbank/feats.scp ark,scp:data/sdm1/train_fbank_gcmvn/feats.ark,data/sdm1/train_fbank_gcmvn/feats.scp -``` -Note that Kaldi always tries to find features in `feats.scp`. Ensure that the normalized features are organized as Kaldi expects them during decoding. - -Finally, put the features and labels together in a file so that MXNet can find them. More specifically, for each data set (train, dev, eval), you will need to create a file similar to `train_mxnet.feats`, with the following contents: - -``` -TRANSFORM scp:feat.scp -scp:label.scp -``` - -`TRANSFORM` is the transformation you want to apply to the features. By default, we use `NO_FEATURE_TRANSFORM`. The `scp:` syntax is from Kaldi. `feat.scp` is typically the file from `data/sdm1/train/feats.scp`, and `label.scp` is converted from the force-aligned labels located in `exp/sdm1/tri3a_ali`. Because the force-alignments are generated only on the training data, we split the training set in two, using a 90/10 ratio, and then use the 1/10 holdout as the dev set (validation set). The script [run_ami.sh](https://github.com/dmlc/mxnet/blob/master/example/speech-demo/run_ami.sh) automatically splits and formats the file for MXNet. Before running it, set the path in the script correctly. The [run_ami.sh](https://github.com/dmlc/mxnet/blob/master/example/speech-demo/run_ami.sh) script actually runs the full pipeline, including training the acoustic model and decoding. If the scripts ran successfully, you can skip the following sections. - -### Run MXNet Acoustic Model Training - -1. Return to the speech demo directory in MXNet. Make a copy of `default.cfg`, and edit the necessary parameters, such as the path to the dataset you just prepared. -2. Run `python train_lstm.py --configfile=your-config.cfg`. For help, use `python train_lstm.py --help`. You can set all of the configuration parameters in `default.cfg`, the customized config file, and through the command line (e.g., using `--train_batch_size=50`). The latter values overwrite the former ones. - -Here are some example outputs from training on the TIMIT dataset: - -``` -Example output for TIMIT: -Summary of dataset ================== -bucket of len 100 : 3 samples -bucket of len 200 : 346 samples -bucket of len 300 : 1496 samples -bucket of len 400 : 974 samples -bucket of len 500 : 420 samples -bucket of len 600 : 90 samples -bucket of len 700 : 11 samples -bucket of len 800 : 2 samples -Summary of dataset ================== -bucket of len 100 : 0 samples -bucket of len 200 : 28 samples -bucket of len 300 : 169 samples -bucket of len 400 : 107 samples -bucket of len 500 : 41 samples -bucket of len 600 : 6 samples -bucket of len 700 : 3 samples -bucket of len 800 : 0 samples -2016-04-21 20:02:40,904 Epoch[0] Train-Acc_exlude_padding=0.154763 -2016-04-21 20:02:40,904 Epoch[0] Time cost=91.574 -2016-04-21 20:02:44,419 Epoch[0] Validation-Acc_exlude_padding=0.353552 -2016-04-21 20:04:17,290 Epoch[1] Train-Acc_exlude_padding=0.447318 -2016-04-21 20:04:17,290 Epoch[1] Time cost=92.870 -2016-04-21 20:04:20,738 Epoch[1] Validation-Acc_exlude_padding=0.506458 -2016-04-21 20:05:53,127 Epoch[2] Train-Acc_exlude_padding=0.557543 -2016-04-21 20:05:53,128 Epoch[2] Time cost=92.390 -2016-04-21 20:05:56,568 Epoch[2] Validation-Acc_exlude_padding=0.548100 -``` - -The final frame accuracy was approximately 62%. - -### Run Decode on the Trained Acoustic Model - -1. Estimate senone priors by running `python make_stats.py --configfile=your-config.cfg | copy-feats ark:- ark:label_mean.ark` (edit necessary items, such as the path to the training dataset). This command generates the label counts in `label_mean.ark`. -2. Link to the necessary Kaldi decode setup, e.g., `local/` and `utils/` and run `./run_ami.sh --model prefix model --num_epoch num`. - -Here are the results for the TIMIT and AMI test sets (using the default setup, three-layer LSTM with projection layers): - - | Corpus | WER | - |--------|-----| - |TIMIT | 18.9| - |AMI | 51.7 (42.2) | - -For AMI 42.2 was evaluated non-overlapped speech. The Kaldi-HMM baseline was 67.2%, and DNN was 57.5%. - -## Next Steps -* [MXNet tutorials index](http://mxnet.io/tutorials/index.html) diff --git a/example/ctc/README.md b/example/ctc/README.md index 9035582a53a3..a2f54cffaf86 100644 --- a/example/ctc/README.md +++ b/example/ctc/README.md @@ -1,80 +1,113 @@ -# CTC with Mxnet +# Connectionist Temporal Classification -## Overview -This example is a modification of [warpctc](https://github.com/dmlc/mxnet/tree/master/example/warpctc) -It demonstrates the usage of ```mx.contrib.sym.ctc_loss``` +[Connectionist Temporal Classification](https://www.cs.toronto.edu/~graves/icml_2006.pdf) (CTC) is a cost function that is used to train Recurrent Neural Networks (RNNs) to label unsegmented input sequence data in supervised learning. For example in a speech recognition application, using a typical cross-entropy loss the input signal needs to be segmented into words or sub-words. However, using CTC-loss, a single unaligned label sequence per input sequence is sufficient for the network to learn both the alignment and labeling. Baidu's warp-ctc page contains a more detailed [introduction to CTC-loss](https://github.com/baidu-research/warp-ctc#introduction). -## Core code change +## LSTM OCR Example +In this example, we use CTC loss to train a network on the problem of Optical Character Recognition (OCR) of CAPTCHA images. This example uses the `captcha` python package to generate a random dataset for training. Training the network requires a CTC-loss layer and MXNet provides two options for such layer. The OCR example is constructed as follows: -The following implementation of ```lstm_unroll``` function is introduced in ```lstm.py``` demonstrates the usage of -```mx.contrib.sym.ctc_loss```. +1. 80x30 CAPTCHA images containing 3 to 4 random digits are generated using python captcha library. +2. Each image is used as a data sequence with sequence-length of 80 and vector length of 30. +3. The output layer uses CTC loss in training and softmax in inference. -```Cython -def lstm_unroll(num_lstm_layer, seq_len, - num_hidden, num_label): - param_cells = [] - last_states = [] - for i in range(num_lstm_layer): - param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i), - i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i), - h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i), - h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i))) - state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i), - h=mx.sym.Variable("l%d_init_h" % i)) - last_states.append(state) - assert (len(last_states) == num_lstm_layer) +Note: When using CTC-loss, one prediction label is reserved for blank label. In this example, when predicting digits between 0 to 9, softmax output has 11 labels, with label 0 used for blank and 1 to 10 used for digit 0 to digit 9 respectively. - # embeding layer - data = mx.sym.Variable('data') - label = mx.sym.Variable('label') - wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1) +### Description of the files +LSTM-OCR example contains the following files: +* `captcha_generator.py`: Module for generating random 3 or 4 digit CAPTCHA images for training. It also contains a script for generating sample CAPTCHA images into an output file for inference testing. +* `ctc_metrics.py`: Module for calculating the prediction accuracy during training. Two accuracy measures are implemented: A simple accuracy measure that calculates number of correct predictions divided by total number of predictions and a second accuracy measure based on sum of Longest Common Sequence (LCS) ratio of all predictions divided by total number of predictions. +* `hyperparameters.py`: Contains all hyperparameters for the network structure and training. +* `lstm.py`: Contains LSTM network implementations. Options for adding mxnet-ctc and warp-ctc loss for training as well as adding softmax for inference are available. +* `lstm_ocr_infer.py`: Script for running inference after training. +* `lstm_ocr_train.py`: Script for training with ctc or warp-ctc loss. +* `multiproc_data.py`: A module for multiprocess data generation. +* `oct_iter.py`: A DataIter module for iterating through training data. - hidden_all = [] - for seqidx in range(seq_len): - hidden = wordvec[seqidx] - for i in range(num_lstm_layer): - next_state = lstm(num_hidden, indata=hidden, - prev_state=last_states[i], - param=param_cells[i], - seqidx=seqidx, layeridx=i) - hidden = next_state.h - last_states[i] = next_state - hidden_all.append(hidden) +## CTC-loss in MXNet +MXNet supports two CTC-loss layers in Symbol API: - hidden_concat = mx.sym.Concat(*hidden_all, dim=0) +* `mxnet.symbol.contrib.ctc_loss` is implemented in MXNet and included as part of the standard package. +* `mxnet.symbol.WarpCTC` uses Baidu's warp-ctc library and requires building warp-ctc library and mxnet library both from source. - pred_fc = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11) - pred_ctc = mx.sym.Reshape(data=pred_fc, shape=(-4, seq_len, -1, 0)) +### Building MXNet with warp-ctc +In order to use `mxnet.symbol.WarpCTC` layer, you need to first build Baidu's [warp-ctc](https://github.com/baidu-research/warp-ctc) library from source and then build MXNet from source with warp-ctc config flags enabled. - loss = mx.contrib.sym.ctc_loss(data=pred_ctc, label=label) - ctc_loss = mx.sym.MakeLoss(loss) +#### Building warp-ctc +You need to first build warp-ctc from source and then install it in your system. Please follow [instructions here](https://github.com/baidu-research/warp-ctc#compilation) to build warp-ctc from source. Once compiled, you need to install the library by running the following command from `warp-ctc/build` directory: +``` +$ sudo make install +``` - softmax_class = mx.symbol.SoftmaxActivation(data=pred_fc) - softmax_loss = mx.sym.MakeLoss(softmax_class) - softmax_loss = mx.sym.BlockGrad(softmax_loss) +#### Building MXNet from source with warp-ctc integration +In order to build MXNet from source, you need to follow [instructions here](http://mxnet.incubator.apache.org/install/index.html). After choosing your system configuration, Python environment, and "Build from Source" options, before running `make` in step 4, you need to enable warp-ctc integration by uncommenting the following lines in `make/config.mk` in `incubator-mxnet` directory: +``` +WARPCTC_PATH = $(HOME)/warp-ctc +MXNET_PLUGINS += plugin/warpctc/warpctc.mk +``` - return mx.sym.Group([softmax_loss, ctc_loss]) +## Run LSTM OCR Example +Running this example requires the following pre-requisites: +* `captcha` and `opencv` python packages are installed: +``` +$ pip install captcha +$ pip install opencv-python +``` +* You have access to one (or more) `ttf` font files. You can download a collection of font files from [Ubuntu's website](https://design.ubuntu.com/font/). The instructions in this section assume that a `./font/Ubuntu-M.ttf` file exists under the `example/ctc/` directory. + +### Training +The training script demonstrates how to construct a network with both CTC loss options and train using `mxnet.Module` API. Training is done by generating random CAPTCHA images using the font(s) provided. This example uses 80x30 captcha images that contain 3 to 4 digits each. + +When using a GPU for training, the training bottleneck will be data generation. To remedy this bottleneck, this example implements a multiprocess data generation. Number of processes for image generation as well as training on CPU or GPU can be configured using command line arguments. + +To see the list of all arguments: +``` +$ python lstm_ocr_train.py --help +``` +Using command line, you can also select between ctc or warp-ctc loss options. For example, the following command initiates a training session on a single GPU with 4 CAPTCHA generating processes using ctc loss and `font/Ubuntu-M.ttf` font file: ``` +$ python lstm_ocr_train.py --gpu 1 --num_proc 4 --loss ctc font/Ubuntu-M.ttf +``` + +You can train with multiple fonts by specifying a folder that contains multiple `ttf` font files instead. The training saves a checkpoint after each epoch. The prefix used for checkpoint is 'ocr' by default, but can be changed with `--prefix` argument. -## Prerequisites +When testing this example, the following system configuration was used: +* p2.xlarge AWS EC2 instance (4 x CPU and 1 x K80 GPU) +* Deep Learning Amazon Machine Image (with mxnet 1.0.0) -Please ensure that following prerequisites are satisfied before running this examples. +This training example finishes after 100 epochs with ~87% accuracy. If you continue training further, the network achieves over 95% accuracy. Similar accuracy is achieved with both ctc (`--loss ctc`) and warp-ctc (`--loss warpctc`) options. Logs of the last training epoch: -- ```captcha``` python package is installed. -- ```cv2``` (or ```openCV```) python package is installed. -- The test requires font file (```ttf``` format). The user either would need to create ```.\data\``` directory and place the font file in that directory. The user can also edit following line to specify path to the font file. -```cython - # you can get this font from http://font.ubuntu.com/ - self.captcha = ImageCaptcha(fonts=['./data/Xerox.ttf']) +``` +05:58:36,128 Epoch[99] Batch [50] Speed: 1067.63 samples/sec accuracy=0.877757 +05:58:42,119 Epoch[99] Batch [100] Speed: 1068.14 samples/sec accuracy=0.859688 +05:58:48,114 Epoch[99] Batch [150] Speed: 1067.73 samples/sec accuracy=0.870469 +05:58:54,107 Epoch[99] Batch [200] Speed: 1067.91 samples/sec accuracy=0.864219 +05:58:58,004 Epoch[99] Train-accuracy=0.877367 +05:58:58,005 Epoch[99] Time cost=28.068 +05:58:58,047 Saved checkpoint to "ocr-0100.params" +05:59:00,721 Epoch[99] Validation-accuracy=0.868886 ``` -## How to run +### Inference +The inference script demonstrates how to load a network from a checkpoint, modify its final layer, and predict a label for a CAPTCHA image using `mxnet.Module` API. You can choose the prefix as well as the epoch number of the checkpoint using command line arguments. To see the full list of arguments: +``` +$ python lstm_ocr_infer.py --help +``` +For example, to predict label for 'sample.jpg' file using 'ocr' prefix and checkpoint at epoch 100: +``` +$ python lstm_ocr_infer.py --prefix ocr --epoch 100 sample.jpg -The users would need to run the script ```lstm_ocr.py``` in order to exercise the above code change. -```cython -python lstm_ocr.py -``` +Digits: [0, 0, 8, 9] +``` -## Further reading +Note: The above command expects the following files, generated by the training script, to exist in the current directory: +* ocr-symbol.json +* ocr-0100.params -In order to run the ```ocr_predict.py``` please refer to [ReadMe](https://github.com/apache/incubator-mxnet/blob/master/example/warpctc/README.md) file in [warpctc](https://github.com/dmlc/mxnet/tree/master/example/warpctc) +#### Generate CAPTCHA samples +CAPTCHA images can be generated using the `captcha_generator.py` script. To see the list of all arguments: +``` +$ python captcha_generator.py --help +``` +For example, to generate a CAPTCHA image with random digits from 'font/Ubuntu-M.ttf' and save to 'sample.jpg' file: +``` +$ python captcha_generator.py font/Ubuntu-M.ttf sample.jpg +``` diff --git a/example/ctc/captcha_generator.py b/example/ctc/captcha_generator.py new file mode 100644 index 000000000000..97fab4082ec0 --- /dev/null +++ b/example/ctc/captcha_generator.py @@ -0,0 +1,214 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Helper classes for multiprocess captcha image generation + +This module also provides script for saving captcha images to file using CLI. +""" + +from __future__ import print_function +import random + +from captcha.image import ImageCaptcha +import cv2 +from multiproc_data import MPData +import numpy as np + + +class CaptchaGen(object): + """ + Generates a captcha image + """ + def __init__(self, h, w, font_paths): + """ + Parameters + ---------- + h: int + Height of the generated images + w: int + Width of the generated images + font_paths: list of str + List of all fonts in ttf format + """ + self.captcha = ImageCaptcha(fonts=font_paths) + self.h = h + self.w = w + + def image(self, captcha_str): + """ + Generate a greyscale captcha image representing number string + + Parameters + ---------- + captcha_str: str + string a characters for captcha image + + Returns + ------- + numpy.ndarray + Generated greyscale image in np.ndarray float type with values normalized to [0, 1] + """ + img = self.captcha.generate(captcha_str) + img = np.fromstring(img.getvalue(), dtype='uint8') + img = cv2.imdecode(img, cv2.IMREAD_GRAYSCALE) + img = cv2.resize(img, (self.h, self.w)) + img = img.transpose(1, 0) + img = np.multiply(img, 1 / 255.0) + return img + + +class DigitCaptcha(object): + """ + Provides shape() and get() interface for digit-captcha image generation + """ + def __init__(self, font_paths, h, w, num_digit_min, num_digit_max): + """ + Parameters + ---------- + font_paths: list of str + List of path to ttf font files + h: int + height of the generated image + w: int + width of the generated image + num_digit_min: int + minimum number of digits in generated captcha image + num_digit_max: int + maximum number of digits in generated captcha image + """ + self.num_digit_min = num_digit_min + self.num_digit_max = num_digit_max + self.captcha = CaptchaGen(h=h, w=w, font_paths=font_paths) + + @property + def shape(self): + """ + Returns shape of the image data generated + + Returns + ------- + tuple(int, int) + """ + return self.captcha.h, self.captcha.w + + def get(self): + """ + Get an image from the queue + + Returns + ------- + np.ndarray + A captcha image, normalized to [0, 1] + """ + return self._gen_sample() + + @staticmethod + def get_rand(num_digit_min, num_digit_max): + """ + Generates a character string of digits. Number of digits are + between self.num_digit_min and self.num_digit_max + Returns + ------- + str + """ + buf = "" + max_len = random.randint(num_digit_min, num_digit_max) + for i in range(max_len): + buf += str(random.randint(0, 9)) + return buf + + def _gen_sample(self): + """ + Generate a random captcha image sample + Returns + ------- + (numpy.ndarray, str) + Tuple of image (numpy ndarray) and character string of digits used to generate the image + """ + num_str = self.get_rand(self.num_digit_min, self.num_digit_max) + return self.captcha.image(num_str), num_str + + +class MPDigitCaptcha(DigitCaptcha): + """ + Handles multi-process captcha image generation + """ + def __init__(self, font_paths, h, w, num_digit_min, num_digit_max, num_processes, max_queue_size): + """ + + Parameters + ---------- + font_paths: list of str + List of path to ttf font files + h: int + height of the generated image + w: int + width of the generated image + num_digit_min: int + minimum number of digits in generated captcha image + num_digit_max: int + maximum number of digits in generated captcha image + num_processes: int + Number of processes to spawn + max_queue_size: int + Maximum images in queue before processes wait + """ + super(MPDigitCaptcha, self).__init__(font_paths, h, w, num_digit_min, num_digit_max) + self.mp_data = MPData(num_processes, max_queue_size, self._gen_sample) + + def start(self): + """ + Starts the processes + """ + self.mp_data.start() + + def get(self): + """ + Get an image from the queue + + Returns + ------- + np.ndarray + A captcha image, normalized to [0, 1] + """ + return self.mp_data.get() + + def reset(self): + """ + Resets the generator by stopping all processes + """ + self.mp_data.reset() + + +if __name__ == '__main__': + import argparse + + def main(): + parser = argparse.ArgumentParser() + parser.add_argument("font_path", help="Path to ttf font file") + parser.add_argument("output", help="Output filename including extension (e.g. 'sample.jpg')") + parser.add_argument("--num", help="Up to 4 digit number [Default: random]") + args = parser.parse_args() + + captcha = ImageCaptcha(fonts=[args.font_path]) + captcha_str = args.num if args.num else DigitCaptcha.get_rand(3, 4) + img = captcha.generate(captcha_str) + img = np.fromstring(img.getvalue(), dtype='uint8') + img = cv2.imdecode(img, cv2.IMREAD_GRAYSCALE) + cv2.imwrite(args.output, img) + print("Captcha image with digits {} written to {}".format([int(c) for c in captcha_str], args.output)) + + main() diff --git a/example/ctc/ctc_metrics.py b/example/ctc/ctc_metrics.py new file mode 100644 index 000000000000..0db680af18d7 --- /dev/null +++ b/example/ctc/ctc_metrics.py @@ -0,0 +1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Contains a class for calculating CTC eval metrics""" + +from __future__ import print_function + +import numpy as np + + +class CtcMetrics(object): + def __init__(self, seq_len): + self.seq_len = seq_len + + @staticmethod + def ctc_label(p): + """ + Iterates through p, identifying non-zero and non-repeating values, and returns them in a list + Parameters + ---------- + p: list of int + + Returns + ------- + list of int + """ + ret = [] + p1 = [0] + p + for i, _ in enumerate(p): + c1 = p1[i] + c2 = p1[i+1] + if c2 == 0 or c2 == c1: + continue + ret.append(c2) + return ret + + @staticmethod + def _remove_blank(l): + """ Removes trailing zeros in the list of integers and returns a new list of integers""" + ret = [] + for i, _ in enumerate(l): + if l[i] == 0: + break + ret.append(l[i]) + return ret + + @staticmethod + def _lcs(p, l): + """ Calculates the Longest Common Subsequence between p and l (both list of int) and returns its length""" + # Dynamic Programming Finding LCS + if len(p) == 0: + return 0 + P = np.array(list(p)).reshape((1, len(p))) + L = np.array(list(l)).reshape((len(l), 1)) + M = np.int32(P == L) + for i in range(M.shape[0]): + for j in range(M.shape[1]): + up = 0 if i == 0 else M[i-1, j] + left = 0 if j == 0 else M[i, j-1] + M[i, j] = max(up, left, M[i, j] if (i == 0 or j == 0) else M[i, j] + M[i-1, j-1]) + return M.max() + + def accuracy(self, label, pred): + """ Simple accuracy measure: number of 100% accurate predictions divided by total number """ + hit = 0. + total = 0. + batch_size = label.shape[0] + for i in range(batch_size): + l = self._remove_blank(label[i]) + p = [] + for k in range(self.seq_len): + p.append(np.argmax(pred[k * batch_size + i])) + p = self.ctc_label(p) + if len(p) == len(l): + match = True + for k, _ in enumerate(p): + if p[k] != int(l[k]): + match = False + break + if match: + hit += 1.0 + total += 1.0 + assert total == batch_size + return hit / total + + def accuracy_lcs(self, label, pred): + """ Longest Common Subsequence accuracy measure: calculate accuracy of each prediction as LCS/length""" + hit = 0. + total = 0. + batch_size = label.shape[0] + for i in range(batch_size): + l = self._remove_blank(label[i]) + p = [] + for k in range(self.seq_len): + p.append(np.argmax(pred[k * batch_size + i])) + p = self.ctc_label(p) + hit += self._lcs(p, l) * 1.0 / len(l) + total += 1.0 + assert total == batch_size + return hit / total + diff --git a/example/ctc/hyperparams.py b/example/ctc/hyperparams.py new file mode 100644 index 000000000000..7289d19c03f1 --- /dev/null +++ b/example/ctc/hyperparams.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Hyperparameters for LSTM OCR Example """ + +from __future__ import print_function + + +class Hyperparams(object): + """ + Hyperparameters for LSTM network + """ + def __init__(self): + # Training hyper parameters + self._train_epoch_size = 30000 + self._eval_epoch_size = 3000 + self._batch_size = 128 + self._num_epoch = 100 + self._learning_rate = 0.001 + self._momentum = 0.9 + self._num_label = 4 + # Network hyper parameters + self._seq_length = 80 + self._num_hidden = 100 + self._num_lstm_layer = 2 + + @property + def train_epoch_size(self): + return self._train_epoch_size + + @property + def eval_epoch_size(self): + return self._eval_epoch_size + + @property + def batch_size(self): + return self._batch_size + + @property + def num_epoch(self): + return self._num_epoch + + @property + def learning_rate(self): + return self._learning_rate + + @property + def momentum(self): + return self._momentum + + @property + def num_label(self): + return self._num_label + + @property + def seq_length(self): + return self._seq_length + + @property + def num_hidden(self): + return self._num_hidden + + @property + def num_lstm_layer(self): + return self._num_lstm_layer diff --git a/example/ctc/lstm.py b/example/ctc/lstm.py index 326daa1d9f3a..dcf8b4e4ef74 100644 --- a/example/ctc/lstm.py +++ b/example/ctc/lstm.py @@ -14,29 +14,24 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""Contain helpers for creating LSTM symbolic graph for training and inference """ -# pylint:skip-file -import sys +from __future__ import print_function -from mxnet.symbol_doc import SymbolDoc +from collections import namedtuple -sys.path.insert(0, "../../python") import mxnet as mx -import numpy as np -from collections import namedtuple -import time -import math + + +__all__ = ["lstm_unroll", "init_states"] + LSTMState = namedtuple("LSTMState", ["c", "h"]) LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias", "h2h_weight", "h2h_bias"]) -LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol", - "init_states", "last_states", - "seq_data", "seq_labels", "seq_outputs", - "param_blocks"]) -def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx): +def _lstm(num_hidden, indata, prev_state, param, seqidx, layeridx): """LSTM Cell symbol""" i2h = mx.sym.FullyConnected(data=indata, weight=param.i2h_weight, @@ -60,8 +55,8 @@ def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx): return LSTMState(c=next_c, h=next_h) -def lstm_unroll(num_lstm_layer, seq_len, - num_hidden, num_label): +def _lstm_unroll_base(num_lstm_layer, seq_len, num_hidden): + """ Returns symbol for LSTM model up to loss/softmax""" param_cells = [] last_states = [] for i in range(num_lstm_layer): @@ -72,35 +67,108 @@ def lstm_unroll(num_lstm_layer, seq_len, state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i), h=mx.sym.Variable("l%d_init_h" % i)) last_states.append(state) - assert (len(last_states) == num_lstm_layer) + assert len(last_states) == num_lstm_layer - # embeding layer + # embedding layer data = mx.sym.Variable('data') - label = mx.sym.Variable('label') wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1) hidden_all = [] for seqidx in range(seq_len): hidden = wordvec[seqidx] for i in range(num_lstm_layer): - next_state = lstm(num_hidden, indata=hidden, - prev_state=last_states[i], - param=param_cells[i], - seqidx=seqidx, layeridx=i) + next_state = _lstm( + num_hidden=num_hidden, + indata=hidden, + prev_state=last_states[i], + param=param_cells[i], + seqidx=seqidx, + layeridx=i) hidden = next_state.h last_states[i] = next_state hidden_all.append(hidden) hidden_concat = mx.sym.Concat(*hidden_all, dim=0) + pred_fc = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11, name="pred_fc") + return pred_fc - pred_fc = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11) - pred_ctc = mx.sym.Reshape(data=pred_fc, shape=(-4, seq_len, -1, 0)) + +def _add_warp_ctc_loss(pred, seq_len, num_label, label): + """ Adds Symbol.contrib.ctc_loss on top of pred symbol and returns the resulting symbol """ + label = mx.sym.Reshape(data=label, shape=(-1,)) + label = mx.sym.Cast(data=label, dtype='int32') + return mx.sym.WarpCTC(data=pred, label=label, label_length=num_label, input_length=seq_len) + + +def _add_mxnet_ctc_loss(pred, seq_len, label): + """ Adds Symbol.WapCTC on top of pred symbol and returns the resulting symbol """ + pred_ctc = mx.sym.Reshape(data=pred, shape=(-4, seq_len, -1, 0)) loss = mx.sym.contrib.ctc_loss(data=pred_ctc, label=label) ctc_loss = mx.sym.MakeLoss(loss) - softmax_class = mx.symbol.SoftmaxActivation(data=pred_fc) + softmax_class = mx.symbol.SoftmaxActivation(data=pred) softmax_loss = mx.sym.MakeLoss(softmax_class) softmax_loss = mx.sym.BlockGrad(softmax_loss) - return mx.sym.Group([softmax_loss, ctc_loss]) + + +def _add_ctc_loss(pred, seq_len, num_label, loss_type): + """ Adds CTC loss on top of pred symbol and returns the resulting symbol """ + label = mx.sym.Variable('label') + if loss_type == 'warpctc': + print("Using WarpCTC Loss") + sm = _add_warp_ctc_loss(pred, seq_len, num_label, label) + else: + print("Using MXNet CTC Loss") + assert loss_type == 'ctc' + sm = _add_mxnet_ctc_loss(pred, seq_len, label) + return sm + + +def lstm_unroll(num_lstm_layer, seq_len, num_hidden, num_label, loss_type=None): + """ + Creates an unrolled LSTM symbol for inference if loss_type is not specified, and for training + if loss_type is specified. loss_type must be one of 'ctc' or 'warpctc' + + Parameters + ---------- + num_lstm_layer: int + seq_len: int + num_hidden: int + num_label: int + loss_type: str + 'ctc' or 'warpctc' + + Returns + ------- + mxnet.symbol.symbol.Symbol + """ + # Create the base (shared between training and inference) and add loss to the end + pred = _lstm_unroll_base(num_lstm_layer, seq_len, num_hidden) + + if loss_type: + # Training mode, add loss + return _add_ctc_loss(pred, seq_len, num_label, loss_type) + else: + # Inference mode, add softmax + return mx.sym.softmax(data=pred, name='softmax') + + +def init_states(batch_size, num_lstm_layer, num_hidden): + """ + Returns name and shape of init states of LSTM network + + Parameters + ---------- + batch_size: list of tuple of str and tuple of int and int + num_lstm_layer: int + num_hidden: int + + Returns + ------- + list of tuple of str and tuple of int and int + """ + init_c = [('l%d_init_c' % l, (batch_size, num_hidden)) for l in range(num_lstm_layer)] + init_h = [('l%d_init_h' % l, (batch_size, num_hidden)) for l in range(num_lstm_layer)] + return init_c + init_h diff --git a/example/ctc/lstm_ocr.py b/example/ctc/lstm_ocr.py deleted file mode 100644 index c9928aa43ab8..000000000000 --- a/example/ctc/lstm_ocr.py +++ /dev/null @@ -1,254 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme -# pylint: disable=superfluous-parens, no-member, invalid-name -from __future__ import print_function -import sys, random -sys.path.insert(0, "../../python") -import numpy as np -import mxnet as mx - -from lstm import lstm_unroll - -from captcha.image import ImageCaptcha -import cv2, random - - -class SimpleBatch(object): - def __init__(self, data_names, data, label_names, label): - self.data = data - self.label = label - self.data_names = data_names - self.label_names = label_names - - self.pad = 0 - self.index = None # TODO: what is index? - - @property - def provide_data(self): - return [(n, x.shape) for n, x in zip(self.data_names, self.data)] - - @property - def provide_label(self): - return [(n, x.shape) for n, x in zip(self.label_names, self.label)] - - -def gen_rand(): - buf = "" - max_len = random.randint(3, 4) - for i in range(max_len): - buf += str(random.randint(0, 9)) - return buf - - -def get_label(buf): - ret = np.zeros(4) - for i in range(len(buf)): - ret[i] = 1 + int(buf[i]) - if len(buf) == 3: - ret[3] = 0 - return ret - - -class OCRIter(mx.io.DataIter): - def __init__(self, count, batch_size, num_label, init_states): - super(OCRIter, self).__init__() - global SEQ_LENGTH - # you can get this font from http://font.ubuntu.com/ - self.captcha = ImageCaptcha(fonts=['./data/Xerox.ttf']) - self.batch_size = batch_size - self.count = count - self.num_label = num_label - self.init_states = init_states - self.init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states] - self.provide_data = [('data', (batch_size, 80, 30))] + init_states - self.provide_label = [('label', (self.batch_size, 4))] - self.cache_data = [] - self.cache_label = [] - - def __iter__(self): - print('iter') - init_state_names = [x[0] for x in self.init_states] - for k in range(self.count): - data = [] - label = [] - for i in range(self.batch_size): - num = gen_rand() - img = self.captcha.generate(num) - img = np.fromstring(img.getvalue(), dtype='uint8') - img = cv2.imdecode(img, cv2.IMREAD_GRAYSCALE) - img = cv2.resize(img, (80, 30)) - img = img.transpose(1, 0) - img = img.reshape((80, 30)) - img = np.multiply(img, 1 / 255.0) - data.append(img) - label.append(get_label(num)) - - data_all = [mx.nd.array(data)] + self.init_state_arrays - label_all = [mx.nd.array(label)] - data_names = ['data'] + init_state_names - label_names = ['label'] - - data_batch = SimpleBatch(data_names, data_all, label_names, label_all) - yield data_batch - - def reset(self): - self.cache_data.clear() - self.cache_label.clear() - pass - - -BATCH_SIZE = 1024 -SEQ_LENGTH = 80 - - -def ctc_label(p): - ret = [] - p1 = [0] + p - for i in range(len(p)): - c1 = p1[i] - c2 = p1[i + 1] - if c2 == 0 or c2 == c1: - continue - ret.append(c2) - return ret - - -def remove_blank(l): - ret = [] - for i in range(len(l)): - if l[i] == 0: - break - ret.append(l[i]) - return ret - - -def Accuracy(label, pred): - global BATCH_SIZE - global SEQ_LENGTH - hit = 0. - total = 0. - rp = np.argmax(pred, axis=1) - for i in range(BATCH_SIZE): - l = remove_blank(label[i]) - p = [] - for k in range(SEQ_LENGTH): - p.append(np.argmax(pred[k * BATCH_SIZE + i])) - p = ctc_label(p) - if len(p) == len(l): - match = True - for k in range(len(p)): - if p[k] != int(l[k]): - match = False - break - if match: - hit += 1.0 - total += 1.0 - return hit / total - - -def LCS(p, l): - # Dynamic Programming Finding LCS - if len(p) == 0: - return 0 - P = np.array(list(p)).reshape((1, len(p))) - L = np.array(list(l)).reshape((len(l), 1)) - M = np.int32(P == L) - for i in range(M.shape[0]): - for j in range(M.shape[1]): - up = 0 if i == 0 else M[i - 1, j] - left = 0 if j == 0 else M[i, j - 1] - M[i, j] = max(up, left, M[i, j] if (i == 0 or j == 0) else M[i, j] + M[i - 1, j - 1]) - return M.max() - - -def Accuracy_LCS(label, pred): - global BATCH_SIZE - global SEQ_LENGTH - hit = 0. - total = 0. - for i in range(BATCH_SIZE): - l = remove_blank(label[i]) - p = [] - for k in range(SEQ_LENGTH): - p.append(np.argmax(pred[k * BATCH_SIZE + i])) - p = ctc_label(p) - hit += LCS(p, l) * 1.0 / len(l) - total += 1.0 - return hit / total - - -def asum_stat(x): - """returns |x|/size(x), async execution.""" - # npx = x.asnumpy() - # print(npx) - return x - return mx.ndarray.norm(x) / np.sqrt(x.size) - - -if __name__ == '__main__': - num_hidden = 100 - num_lstm_layer = 2 - - num_epoch = 100 - learning_rate = 0.01 - momentum = 0.9 - num_label = 4 - - contexts = [mx.context.gpu(0)] - - - def sym_gen(seq_len): - return lstm_unroll(num_lstm_layer, seq_len, - num_hidden=num_hidden, - num_label=num_label) - - - init_c = [('l%d_init_c' % l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer)] - init_h = [('l%d_init_h' % l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer)] - init_states = init_c + init_h - - data_train = OCRIter(20000, BATCH_SIZE, num_label, init_states) - data_val = OCRIter(1000, BATCH_SIZE, num_label, init_states) - - symbol = sym_gen(SEQ_LENGTH) - - import logging - - head = '%(asctime)-15s %(message)s' - logging.basicConfig(level=logging.DEBUG, format=head) - - print('begin fit') - - module = mx.mod.Module(symbol, data_names=['data', 'l0_init_c', 'l0_init_h', 'l1_init_c', 'l1_init_h'], - label_names=['label'], - context=contexts) - - module.fit(train_data=data_train, - eval_data=data_val, - eval_metric=mx.metric.np(Accuracy, allow_extra_outputs=True), - optimizer='sgd', - optimizer_params={'learning_rate': learning_rate, - 'momentum': momentum, - 'wd': 0.00001, - }, - initializer=mx.init.Xavier(factor_type="in", magnitude=2.34), - num_epoch=num_epoch, - batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 50), - epoch_end_callback=mx.callback.do_checkpoint("ocr"), - ) diff --git a/example/ctc/lstm_ocr_infer.py b/example/ctc/lstm_ocr_infer.py new file mode 100644 index 000000000000..80de2c7efac4 --- /dev/null +++ b/example/ctc/lstm_ocr_infer.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" An example of predicting CAPTCHA image data with a LSTM network pre-trained with a CTC loss""" + +from __future__ import print_function + +import argparse + +from ctc_metrics import CtcMetrics +import cv2 +from hyperparams import Hyperparams +import lstm +import mxnet as mx +import numpy as np +from ocr_iter import SimpleBatch + + +def read_img(path): + """ Reads image specified by path into numpy.ndarray""" + img = cv2.resize(cv2.imread(path, 0), (80, 30)).astype(np.float32) / 255 + img = np.expand_dims(img.transpose(1, 0), 0) + return img + + +def lstm_init_states(batch_size): + """ Returns a tuple of names and zero arrays for LSTM init states""" + hp = Hyperparams() + init_shapes = lstm.init_states(batch_size=batch_size, num_lstm_layer=hp.num_lstm_layer, num_hidden=hp.num_hidden) + init_names = [s[0] for s in init_shapes] + init_arrays = [mx.nd.zeros(x[1]) for x in init_shapes] + return init_names, init_arrays + + +def load_module(prefix, epoch, data_names, data_shapes): + """ + Loads the model from checkpoint specified by prefix and epoch, binds it + to an executor, and sets its parameters and returns a mx.mod.Module + """ + sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) + + # We don't need CTC loss for prediction, just a simple softmax will suffice. + # We get the output of the layer just before the loss layer ('pred_fc') and add softmax on top + pred_fc = sym.get_internals()['pred_fc_output'] + sym = mx.sym.softmax(data=pred_fc) + + mod = mx.mod.Module(symbol=sym, context=mx.cpu(), data_names=data_names, label_names=None) + mod.bind(for_training=False, data_shapes=data_shapes) + mod.set_params(arg_params, aux_params, allow_missing=False) + return mod + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("path", help="Path to the CAPTCHA image file") + parser.add_argument("--prefix", help="Checkpoint prefix [Default 'ocr']", default='ocr') + parser.add_argument("--epoch", help="Checkpoint epoch [Default 100]", type=int, default=100) + args = parser.parse_args() + + init_state_names, init_state_arrays = lstm_init_states(batch_size=1) + img = read_img(args.path) + + sample = SimpleBatch( + data_names=['data'] + init_state_names, + data=[mx.nd.array(img)] + init_state_arrays) + + mod = load_module(args.prefix, args.epoch, sample.data_names, sample.provide_data) + + mod.forward(sample) + prob = mod.get_outputs()[0].asnumpy() + + prediction = CtcMetrics.ctc_label(np.argmax(prob, axis=-1).tolist()) + # Predictions are 1 to 10 for digits 0 to 9 respectively (prediction 0 means no-digit) + prediction = [p - 1 for p in prediction] + print("Digits:", prediction) + return + + +if __name__ == '__main__': + main() diff --git a/example/ctc/lstm_ocr_train.py b/example/ctc/lstm_ocr_train.py new file mode 100644 index 000000000000..2c25f7e31e11 --- /dev/null +++ b/example/ctc/lstm_ocr_train.py @@ -0,0 +1,125 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" An example of using WarpCTC loss for an OCR problem using LSTM and CAPTCHA image data""" + +from __future__ import print_function + +import argparse +import logging +import os + +from captcha_generator import MPDigitCaptcha +from hyperparams import Hyperparams +from ctc_metrics import CtcMetrics +import lstm +import mxnet as mx +from ocr_iter import OCRIter + + +def get_fonts(path): + fonts = list() + if os.path.isdir(path): + for filename in os.listdir(path): + if filename.endswith('.ttf'): + fonts.append(os.path.join(path, filename)) + else: + fonts.append(path) + return fonts + + +def parse_args(): + # Parse command line arguments + parser = argparse.ArgumentParser() + parser.add_argument("font_path", help="Path to ttf font file or directory containing ttf files") + parser.add_argument("--loss", help="'ctc' or 'warpctc' loss [Default 'ctc']", default='ctc') + parser.add_argument("--cpu", + help="Number of CPUs for training [Default 8]. Ignored if --gpu is specified.", + type=int, default=8) + parser.add_argument("--gpu", help="Number of GPUs for training [Default 0]", type=int) + parser.add_argument("--num_proc", help="Number CAPTCHA generating processes [Default 4]", type=int, default=4) + parser.add_argument("--prefix", help="Checkpoint prefix [Default 'ocr']", default='ocr') + return parser.parse_args() + + +def main(): + args = parse_args() + if not any(args.loss == s for s in ['ctc', 'warpctc']): + raise ValueError("Invalid loss '{}' (must be 'ctc' or 'warpctc')".format(args.loss)) + + hp = Hyperparams() + + # Start a multiprocessor captcha image generator + mp_captcha = MPDigitCaptcha( + font_paths=get_fonts(args.font_path), h=hp.seq_length, w=30, + num_digit_min=3, num_digit_max=4, num_processes=args.num_proc, max_queue_size=hp.batch_size * 2) + try: + # Must call start() before any call to mxnet module (https://github.com/apache/incubator-mxnet/issues/9213) + mp_captcha.start() + + if args.gpu: + contexts = [mx.context.gpu(i) for i in range(args.gpu)] + else: + contexts = [mx.context.cpu(i) for i in range(args.cpu)] + + init_states = lstm.init_states(hp.batch_size, hp.num_lstm_layer, hp.num_hidden) + + data_train = OCRIter( + hp.train_epoch_size // hp.batch_size, hp.batch_size, init_states, captcha=mp_captcha, name='train') + data_val = OCRIter( + hp.eval_epoch_size // hp.batch_size, hp.batch_size, init_states, captcha=mp_captcha, name='val') + + symbol = lstm.lstm_unroll( + num_lstm_layer=hp.num_lstm_layer, + seq_len=hp.seq_length, + num_hidden=hp.num_hidden, + num_label=hp.num_label, + loss_type=args.loss) + + head = '%(asctime)-15s %(message)s' + logging.basicConfig(level=logging.DEBUG, format=head) + + module = mx.mod.Module( + symbol, + data_names=['data', 'l0_init_c', 'l0_init_h', 'l1_init_c', 'l1_init_h'], + label_names=['label'], + context=contexts) + + metrics = CtcMetrics(hp.seq_length) + module.fit(train_data=data_train, + eval_data=data_val, + # use metrics.accuracy or metrics.accuracy_lcs + eval_metric=mx.metric.np(metrics.accuracy, allow_extra_outputs=True), + optimizer='sgd', + optimizer_params={'learning_rate': hp.learning_rate, + 'momentum': hp.momentum, + 'wd': 0.00001, + }, + initializer=mx.init.Xavier(factor_type="in", magnitude=2.34), + num_epoch=hp.num_epoch, + batch_end_callback=mx.callback.Speedometer(hp.batch_size, 50), + epoch_end_callback=mx.callback.do_checkpoint(args.prefix), + ) + except KeyboardInterrupt: + print("W: interrupt received, stopping...") + finally: + # Reset multiprocessing captcha generator to stop processes + mp_captcha.reset() + + +if __name__ == '__main__': + main() + diff --git a/example/ctc/multiproc_data.py b/example/ctc/multiproc_data.py new file mode 100644 index 000000000000..c5f8da56355a --- /dev/null +++ b/example/ctc/multiproc_data.py @@ -0,0 +1,144 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import print_function +from ctypes import c_bool +import multiprocessing as mp +try: + from queue import Full as QFullExcept + from queue import Empty as QEmptyExcept +except ImportError: + from Queue import Full as QFullExcept + from Queue import Empty as QEmptyExcept + +import numpy as np + + +class MPData(object): + """ + Handles multi-process data generation. + + Operation: + - call start() to start the data generation + - call get() (blocking) to read one sample + - call reset() to stop data generation + """ + def __init__(self, num_processes, max_queue_size, fn): + """ + + Parameters + ---------- + num_processes: int + Number of processes to spawn + max_queue_size: int + Maximum samples in the queue before processes wait + fn: function + function that generates samples, executed on separate processes. + """ + self.queue = mp.Queue(maxsize=int(max_queue_size)) + self.alive = mp.Value(c_bool, False, lock=False) + self.num_proc = num_processes + self.proc = list() + self.fn = fn + + def start(self): + """ + Starts the processes + Parameters + ---------- + fn: function + + """ + """ + Starts the processes + """ + self._init_proc() + + @staticmethod + def _proc_loop(proc_id, alive, queue, fn): + """ + Thread loop for generating data + + Parameters + ---------- + proc_id: int + Process id + alive: multiprocessing.Value + variable for signaling whether process should continue or not + queue: multiprocessing.Queue + queue for passing data back + fn: function + function object that returns a sample to be pushed into the queue + """ + print("proc {} started".format(proc_id)) + try: + while alive.value: + data = fn() + put_success = False + while alive.value and not put_success: + try: + queue.put(data, timeout=0.5) + put_success = True + except QFullExcept: + # print("Queue Full") + pass + except KeyboardInterrupt: + print("W: interrupt received, stopping process {} ...".format(proc_id)) + print("Closing process {}".format(proc_id)) + queue.close() + + def _init_proc(self): + """ + Start processes if not already started + """ + if not self.proc: + self.proc = [ + mp.Process(target=self._proc_loop, args=(i, self.alive, self.queue, self.fn)) + for i in range(self.num_proc) + ] + self.alive.value = True + for p in self.proc: + p.start() + + def get(self): + """ + Get a datum from the queue + + Returns + ------- + np.ndarray + A captcha image, normalized to [0, 1] + """ + self._init_proc() + return self.queue.get() + + def reset(self): + """ + Resets the generator by stopping all processes + """ + self.alive.value = False + qsize = 0 + try: + while True: + self.queue.get(timeout=0.1) + qsize += 1 + except QEmptyExcept: + pass + print("Queue size on reset: {}".format(qsize)) + for i, p in enumerate(self.proc): + p.join() + self.proc.clear() diff --git a/example/ctc/ocr_iter.py b/example/ctc/ocr_iter.py new file mode 100644 index 000000000000..1432e92a80fd --- /dev/null +++ b/example/ctc/ocr_iter.py @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Iterator for Captcha images used for LSTM-based OCR model""" + +from __future__ import print_function + +import numpy as np +import mxnet as mx + + +class SimpleBatch(object): + def __init__(self, data_names, data, label_names=list(), label=list()): + self._data = data + self._label = label + self._data_names = data_names + self._label_names = label_names + + self.pad = 0 + self.index = None # TODO: what is index? + + @property + def data(self): + return self._data + + @property + def label(self): + return self._label + + @property + def data_names(self): + return self._data_names + + @property + def label_names(self): + return self._label_names + + @property + def provide_data(self): + return [(n, x.shape) for n, x in zip(self._data_names, self._data)] + + @property + def provide_label(self): + return [(n, x.shape) for n, x in zip(self._label_names, self._label)] + + +def get_label(buf): + ret = np.zeros(4) + for i in range(len(buf)): + ret[i] = 1 + int(buf[i]) + if len(buf) == 3: + ret[3] = 0 + return ret + + +class OCRIter(mx.io.DataIter): + """ + Iterator class for generating captcha image data + """ + def __init__(self, count, batch_size, lstm_init_states, captcha, name): + """ + Parameters + ---------- + count: int + Number of batches to produce for one epoch + batch_size: int + lstm_init_states: list of tuple(str, tuple) + A list of tuples with [0] name and [1] shape of each LSTM init state + captcha MPCaptcha + Captcha image generator. Can be MPCaptcha or any other class providing .shape and .get() interface + name: str + """ + super(OCRIter, self).__init__() + self.batch_size = batch_size + self.count = count + self.init_states = lstm_init_states + self.init_state_arrays = [mx.nd.zeros(x[1]) for x in lstm_init_states] + data_shape = captcha.shape + self.provide_data = [('data', (batch_size, data_shape[0], data_shape[1]))] + lstm_init_states + self.provide_label = [('label', (self.batch_size, 4))] + self.mp_captcha = captcha + self.name = name + + def __iter__(self): + init_state_names = [x[0] for x in self.init_states] + for k in range(self.count): + data = [] + label = [] + for i in range(self.batch_size): + img, num = self.mp_captcha.get() + data.append(img) + label.append(get_label(num)) + data_all = [mx.nd.array(data)] + self.init_state_arrays + label_all = [mx.nd.array(label)] + data_names = ['data'] + init_state_names + label_names = ['label'] + + data_batch = SimpleBatch(data_names, data_all, label_names, label_all) + yield data_batch diff --git a/example/ctc/ocr_predict.py b/example/ctc/ocr_predict.py index 3096a664a20f..2cf19678f4b5 100644 --- a/example/ctc/ocr_predict.py +++ b/example/ctc/ocr_predict.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python2.7 - # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -16,24 +14,28 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +""" An example of predicting CAPTCHA image data with a LSTM network pre-trained with a CTC loss""" -# coding=utf-8 from __future__ import print_function -import sys, os -curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) -sys.path.append("../../amalgamation/python/") -sys.path.append("../../python/") -from mxnet_predict import Predictor -import mxnet as mx +import argparse -import numpy as np +import sys import cv2 -import os +import numpy as np +import mxnet as mx +from collections import namedtuple +from ocr_iter import SimpleBatch +from captcha_generator import DigitCaptcha +from ctc_metrics import CtcMetrics +import lstm +from hyperparams import Hyperparams + class lstm_ocr_model(object): # Keep Zero index for blank. (CTC request it) - CONST_CHAR='0123456789' + CONST_CHAR = '0123456789' + def __init__(self, path_of_json, path_of_params): super(lstm_ocr_model, self).__init__() self.path_of_json = path_of_json @@ -52,32 +54,37 @@ def __init_ocr(self): init_states = init_c + init_h init_state_arrays = np.zeros((batch_size, num_hidden), dtype="float32") - self.init_state_dict={} + self.init_state_dict = {} for x in init_states: self.init_state_dict[x[0]] = init_state_arrays - all_shapes = [('data', (batch_size, 80 * 30))] + init_states + [('label', (batch_size, num_label))] + all_shapes = [('data', (batch_size, 80, 30))] + init_states + [('label', (batch_size, num_label))] all_shapes_dict = {} for _shape in all_shapes: all_shapes_dict[_shape[0]] = _shape[1] - self.predictor = Predictor(open(self.path_of_json).read(), - open(self.path_of_params).read(), - all_shapes_dict) - - def forward_ocr(self, img): - img = cv2.resize(img, (80, 30)) - img = img.transpose(1, 0) - img = img.reshape((80 * 30)) - img = np.multiply(img, 1/255.0) - self.predictor.forward(data=img, **self.init_state_dict) + self.predictor = Predictor(open(self.path_of_json, 'rb').read(), + open(self.path_of_params, 'rb').read(), + all_shapes_dict) + + def forward_ocr(self, img_): + img_ = cv2.resize(img_, (80, 30)) + img_ = img_.transpose(1, 0) + print(img_.shape) + img_ = img_.reshape((1, 80, 30)) + print(img_.shape) + # img_ = img_.reshape((80 * 30)) + img_ = np.multiply(img_, 1 / 255.0) + self.predictor.forward(data=img_, **self.init_state_dict) prob = self.predictor.get_output(0) label_list = [] for p in prob: + print(np.argsort(p)) max_index = np.argsort(p)[::-1][0] label_list.append(max_index) return self.__get_string(label_list) - def __get_string(self, label_list): + @staticmethod + def __get_string(label_list): # Do CTC label rule # CTC cannot emit a repeated symbol on consecutive timesteps ret = [] @@ -98,9 +105,55 @@ def __get_string(self, label_list): s += c return s + if __name__ == '__main__': + # parser = argparse.ArgumentParser() + # parser.add_argument("path", help="Path to the CAPTCHA image file") + # parser.add_argument("--prefix", help="Checkpoint prefix [Default 'ocr']", default='ocr') + # parser.add_argument("--epoch", help="Checkpoint epoch [Default 100]", type=int, default=100) + # args = parser.parse_args() + # + # # Create array of zeros for LSTM init states + # hp = Hyperparams() + # init_states = lstm.init_states(batch_size=1, num_lstm_layer=hp.num_lstm_layer, num_hidden=hp.num_hidden) + # init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states] + # # Read the image into an ndarray + # img = cv2.resize(cv2.imread(args.path, 0), (80, 30)).astype(np.float32) / 255 + # img = np.expand_dims(img.transpose(1, 0), 0) + # + # data_names = ['data'] + [s[0] for s in init_states] + # sample = SimpleBatch(data_names, data=[mx.nd.array(img)] + init_state_arrays) + # + # sym, arg_params, aux_params = mx.model.load_checkpoint(args.prefix, args.epoch) + # + # # We don't need CTC loss for prediction, just a simple softmax will suffice. + # # We get the output of the layer just before the loss layer ('pred_fc') and add softmax on top + # pred_fc = sym.get_internals()['pred_fc_output'] + # sym = mx.sym.softmax(data=pred_fc) + # + # mod = mx.mod.Module(symbol=sym, context=mx.cpu(), data_names=data_names, label_names=None) + # mod.bind(for_training=False, data_shapes=sample.provide_data) + # mod.set_params(arg_params, aux_params, allow_missing=False) + # + # mod.forward(sample) + # prob = mod.get_outputs()[0].asnumpy() + # + # label_list = list() + # prediction = CtcMetrics.ctc_label(np.argmax(prob, axis=-1).tolist()) + # # Predictions are 1 to 10 for digits 0 to 9 respectively (prediction 0 means no-digit) + # prediction = [p - 1 for p in prediction] + # print("Digits:", prediction) + # exit(0) + # + + parser = argparse.ArgumentParser() + parser.add_argument("predict_lib_path", help="Path to directory containing mxnet_predict.so") + args = parser.parse_args() + + sys.path.append(args.predict_lib_path + "/python") + from mxnet_predict import Predictor + _lstm_ocr_model = lstm_ocr_model('ocr-symbol.json', 'ocr-0010.params') - img = cv2.imread('sample.jpg', 0) + img = cv2.imread('sample0.png', 0) _str = _lstm_ocr_model.forward_ocr(img) print('Result: ', _str) - diff --git a/example/warpctc/sample.jpg b/example/ctc/sample.jpg similarity index 100% rename from example/warpctc/sample.jpg rename to example/ctc/sample.jpg diff --git a/example/deep-embedded-clustering/autoencoder.py b/example/deep-embedded-clustering/autoencoder.py new file mode 100644 index 000000000000..096f04529c3b --- /dev/null +++ b/example/deep-embedded-clustering/autoencoder.py @@ -0,0 +1,206 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=missing-docstring, arguments-differ +from __future__ import print_function + +import logging + +import mxnet as mx +import numpy as np +import model +from solver import Solver, Monitor + + +class AutoEncoderModel(model.MXModel): + def setup(self, dims, sparseness_penalty=None, pt_dropout=None, + ft_dropout=None, input_act=None, internal_act='relu', output_act=None): + self.N = len(dims) - 1 + self.dims = dims + self.stacks = [] + self.pt_dropout = pt_dropout + self.ft_dropout = ft_dropout + self.input_act = input_act + self.internal_act = internal_act + self.output_act = output_act + + self.data = mx.symbol.Variable('data') + for i in range(self.N): + if i == 0: + decoder_act = input_act + idropout = None + else: + decoder_act = internal_act + idropout = pt_dropout + if i == self.N-1: + encoder_act = output_act + odropout = None + else: + encoder_act = internal_act + odropout = pt_dropout + istack, iargs, iargs_grad, iargs_mult, iauxs = self.make_stack( + i, self.data, dims[i], dims[i+1], sparseness_penalty, + idropout, odropout, encoder_act, decoder_act + ) + self.stacks.append(istack) + self.args.update(iargs) + self.args_grad.update(iargs_grad) + self.args_mult.update(iargs_mult) + self.auxs.update(iauxs) + self.encoder, self.internals = self.make_encoder( + self.data, dims, sparseness_penalty, ft_dropout, internal_act, output_act) + self.decoder = self.make_decoder( + self.encoder, dims, sparseness_penalty, ft_dropout, internal_act, input_act) + if input_act == 'softmax': + self.loss = self.decoder + else: + self.loss = mx.symbol.LinearRegressionOutput(data=self.decoder, label=self.data) + + def make_stack(self, istack, data, num_input, num_hidden, sparseness_penalty=None, + idropout=None, odropout=None, encoder_act='relu', decoder_act='relu'): + x = data + if idropout: + x = mx.symbol.Dropout(data=x, p=idropout) + x = mx.symbol.FullyConnected(name='encoder_%d'%istack, data=x, num_hidden=num_hidden) + if encoder_act: + x = mx.symbol.Activation(data=x, act_type=encoder_act) + if encoder_act == 'sigmoid' and sparseness_penalty: + x = mx.symbol.IdentityAttachKLSparseReg( + data=x, name='sparse_encoder_%d' % istack, penalty=sparseness_penalty) + if odropout: + x = mx.symbol.Dropout(data=x, p=odropout) + x = mx.symbol.FullyConnected(name='decoder_%d'%istack, data=x, num_hidden=num_input) + if decoder_act == 'softmax': + x = mx.symbol.Softmax(data=x, label=data, prob_label=True, act_type=decoder_act) + elif decoder_act: + x = mx.symbol.Activation(data=x, act_type=decoder_act) + if decoder_act == 'sigmoid' and sparseness_penalty: + x = mx.symbol.IdentityAttachKLSparseReg( + data=x, name='sparse_decoder_%d' % istack, penalty=sparseness_penalty) + x = mx.symbol.LinearRegressionOutput(data=x, label=data) + else: + x = mx.symbol.LinearRegressionOutput(data=x, label=data) + + args = {'encoder_%d_weight'%istack: mx.nd.empty((num_hidden, num_input), self.xpu), + 'encoder_%d_bias'%istack: mx.nd.empty((num_hidden,), self.xpu), + 'decoder_%d_weight'%istack: mx.nd.empty((num_input, num_hidden), self.xpu), + 'decoder_%d_bias'%istack: mx.nd.empty((num_input,), self.xpu),} + args_grad = {'encoder_%d_weight'%istack: mx.nd.empty((num_hidden, num_input), self.xpu), + 'encoder_%d_bias'%istack: mx.nd.empty((num_hidden,), self.xpu), + 'decoder_%d_weight'%istack: mx.nd.empty((num_input, num_hidden), self.xpu), + 'decoder_%d_bias'%istack: mx.nd.empty((num_input,), self.xpu),} + args_mult = {'encoder_%d_weight'%istack: 1.0, + 'encoder_%d_bias'%istack: 2.0, + 'decoder_%d_weight'%istack: 1.0, + 'decoder_%d_bias'%istack: 2.0,} + auxs = {} + if encoder_act == 'sigmoid' and sparseness_penalty: + auxs['sparse_encoder_%d_moving_avg' % istack] = mx.nd.ones(num_hidden, self.xpu) * 0.5 + if decoder_act == 'sigmoid' and sparseness_penalty: + auxs['sparse_decoder_%d_moving_avg' % istack] = mx.nd.ones(num_input, self.xpu) * 0.5 + init = mx.initializer.Uniform(0.07) + for k, v in args.items(): + init(mx.initializer.InitDesc(k), v) + + return x, args, args_grad, args_mult, auxs + + def make_encoder(self, data, dims, sparseness_penalty=None, dropout=None, internal_act='relu', + output_act=None): + x = data + internals = [] + N = len(dims) - 1 + for i in range(N): + x = mx.symbol.FullyConnected(name='encoder_%d'%i, data=x, num_hidden=dims[i+1]) + if internal_act and i < N-1: + x = mx.symbol.Activation(data=x, act_type=internal_act) + if internal_act == 'sigmoid' and sparseness_penalty: + x = mx.symbol.IdentityAttachKLSparseReg( + data=x, name='sparse_encoder_%d' % i, penalty=sparseness_penalty) + elif output_act and i == N-1: + x = mx.symbol.Activation(data=x, act_type=output_act) + if output_act == 'sigmoid' and sparseness_penalty: + x = mx.symbol.IdentityAttachKLSparseReg( + data=x, name='sparse_encoder_%d' % i, penalty=sparseness_penalty) + if dropout: + x = mx.symbol.Dropout(data=x, p=dropout) + internals.append(x) + return x, internals + + def make_decoder(self, feature, dims, sparseness_penalty=None, dropout=None, + internal_act='relu', input_act=None): + x = feature + N = len(dims) - 1 + for i in reversed(range(N)): + x = mx.symbol.FullyConnected(name='decoder_%d'%i, data=x, num_hidden=dims[i]) + if internal_act and i > 0: + x = mx.symbol.Activation(data=x, act_type=internal_act) + if internal_act == 'sigmoid' and sparseness_penalty: + x = mx.symbol.IdentityAttachKLSparseReg( + data=x, name='sparse_decoder_%d' % i, penalty=sparseness_penalty) + elif input_act and i == 0: + x = mx.symbol.Activation(data=x, act_type=input_act) + if input_act == 'sigmoid' and sparseness_penalty: + x = mx.symbol.IdentityAttachKLSparseReg( + data=x, name='sparse_decoder_%d' % i, penalty=sparseness_penalty) + if dropout and i > 0: + x = mx.symbol.Dropout(data=x, p=dropout) + return x + + def layerwise_pretrain(self, X, batch_size, n_iter, optimizer, l_rate, decay, + lr_scheduler=None, print_every=1000): + def l2_norm(label, pred): + return np.mean(np.square(label-pred))/2.0 + solver = Solver(optimizer, momentum=0.9, wd=decay, learning_rate=l_rate, + lr_scheduler=lr_scheduler) + solver.set_metric(mx.metric.CustomMetric(l2_norm)) + solver.set_monitor(Monitor(print_every)) + data_iter = mx.io.NDArrayIter({'data': X}, batch_size=batch_size, shuffle=True, + last_batch_handle='roll_over') + for i in range(self.N): + if i == 0: + data_iter_i = data_iter + else: + X_i = list(model.extract_feature( + self.internals[i-1], self.args, self.auxs, data_iter, X.shape[0], + self.xpu).values())[0] + data_iter_i = mx.io.NDArrayIter({'data': X_i}, batch_size=batch_size, + last_batch_handle='roll_over') + logging.info('Pre-training layer %d...', i) + solver.solve(self.xpu, self.stacks[i], self.args, self.args_grad, self.auxs, + data_iter_i, 0, n_iter, {}, False) + + def finetune(self, X, batch_size, n_iter, optimizer, l_rate, decay, lr_scheduler=None, + print_every=1000): + def l2_norm(label, pred): + return np.mean(np.square(label-pred))/2.0 + solver = Solver(optimizer, momentum=0.9, wd=decay, learning_rate=l_rate, + lr_scheduler=lr_scheduler) + solver.set_metric(mx.metric.CustomMetric(l2_norm)) + solver.set_monitor(Monitor(print_every)) + data_iter = mx.io.NDArrayIter({'data': X}, batch_size=batch_size, shuffle=True, + last_batch_handle='roll_over') + logging.info('Fine tuning...') + solver.solve(self.xpu, self.loss, self.args, self.args_grad, self.auxs, data_iter, + 0, n_iter, {}, False) + + def eval(self, X): + batch_size = 100 + data_iter = mx.io.NDArrayIter({'data': X}, batch_size=batch_size, shuffle=False, + last_batch_handle='pad') + Y = list(model.extract_feature( + self.loss, self.args, self.auxs, data_iter, X.shape[0], self.xpu).values())[0] + return np.mean(np.square(Y-X))/2.0 \ No newline at end of file diff --git a/example/deep-embedded-clustering/data.py b/example/deep-embedded-clustering/data.py new file mode 100644 index 000000000000..9fd472e6a8b1 --- /dev/null +++ b/example/deep-embedded-clustering/data.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=missing-docstring +from __future__ import print_function + +import os +import numpy as np +from sklearn.datasets import fetch_mldata + + +def get_mnist(): + """ Gets MNIST dataset """ + + np.random.seed(1234) # set seed for deterministic ordering + data_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) + data_path = os.path.join(data_path, '../../data') + mnist = fetch_mldata('MNIST original', data_home=data_path) + p = np.random.permutation(mnist.data.shape[0]) + X = mnist.data[p].astype(np.float32)*0.02 + Y = mnist.target[p] + return X, Y + + + + diff --git a/example/deep-embedded-clustering/model.py b/example/deep-embedded-clustering/model.py new file mode 100644 index 000000000000..777634e3cf88 --- /dev/null +++ b/example/deep-embedded-clustering/model.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=missing-docstring +from __future__ import print_function + +import mxnet as mx +import numpy as np +try: + import cPickle as pickle +except ModuleNotFoundError: + import pickle + + +def extract_feature(sym, args, auxs, data_iter, N, xpu=mx.cpu()): + input_buffs = [mx.nd.empty(shape, ctx=xpu) for k, shape in data_iter.provide_data] + input_names = [k for k, shape in data_iter.provide_data] + args = dict(args, **dict(zip(input_names, input_buffs))) + exe = sym.bind(xpu, args=args, aux_states=auxs) + outputs = [[] for _ in exe.outputs] + output_buffs = None + + data_iter.hard_reset() + for batch in data_iter: + for data, buff in zip(batch.data, input_buffs): + data.copyto(buff) + exe.forward(is_train=False) + if output_buffs is None: + output_buffs = [mx.nd.empty(i.shape, ctx=mx.cpu()) for i in exe.outputs] + else: + for out, buff in zip(outputs, output_buffs): + out.append(buff.asnumpy()) + for out, buff in zip(exe.outputs, output_buffs): + out.copyto(buff) + for out, buff in zip(outputs, output_buffs): + out.append(buff.asnumpy()) + outputs = [np.concatenate(i, axis=0)[:N] for i in outputs] + return dict(zip(sym.list_outputs(), outputs)) + + +class MXModel(object): + def __init__(self, xpu=mx.cpu(), *args, **kwargs): + self.xpu = xpu + self.loss = None + self.args = {} + self.args_grad = {} + self.args_mult = {} + self.auxs = {} + self.setup(*args, **kwargs) + + def save(self, fname): + args_save = {key: v.asnumpy() for key, v in self.args.items()} + with open(fname, 'wb') as fout: + pickle.dump(args_save, fout) + + def load(self, fname): + with open(fname, 'rb') as fin: + args_save = pickle.load(fin) + for key, v in args_save.items(): + if key in self.args: + self.args[key][:] = v + + def setup(self, *args, **kwargs): + raise NotImplementedError("must override this") \ No newline at end of file diff --git a/example/deep-embedded-clustering/solver.py b/example/deep-embedded-clustering/solver.py new file mode 100644 index 000000000000..567c78eeb06c --- /dev/null +++ b/example/deep-embedded-clustering/solver.py @@ -0,0 +1,151 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=missing-docstring +from __future__ import print_function + +import logging + +import mxnet as mx +import numpy as np + + +class Monitor(object): + def __init__(self, interval, level=logging.DEBUG, stat=None): + self.interval = interval + self.level = level + if stat is None: + def mean_abs(x): + return np.fabs(x).mean() + self.stat = mean_abs + else: + self.stat = stat + + def forward_end(self, i, internals): + if i % self.interval == 0 and logging.getLogger().isEnabledFor(self.level): + for key in sorted(internals.keys()): + arr = internals[key] + logging.log(self.level, 'Iter:%d param:%s\t\tstat(%s):%s', + i, key, self.stat.__name__, str(self.stat(arr.asnumpy()))) + + def backward_end(self, i, weights, grads, metric=None): + if i % self.interval == 0 and logging.getLogger().isEnabledFor(self.level): + for key in sorted(grads.keys()): + arr = grads[key] + logging.log(self.level, 'Iter:%d param:%s\t\tstat(%s):%s\t\tgrad_stat:%s', + i, key, self.stat.__name__, + str(self.stat(weights[key].asnumpy())), str(self.stat(arr.asnumpy()))) + if i % self.interval == 0 and metric is not None: + logging.log(logging.INFO, 'Iter:%d metric:%f', i, metric.get()[1]) + metric.reset() + + +class Solver(object): + def __init__(self, optimizer, **kwargs): + if isinstance(optimizer, str): + self.optimizer = mx.optimizer.create(optimizer, **kwargs) + else: + self.optimizer = optimizer + self.updater = mx.optimizer.get_updater(self.optimizer) + self.monitor = None + self.metric = None + self.iter_end_callback = None + self.iter_start_callback = None + + def set_metric(self, metric): + self.metric = metric + + def set_monitor(self, monitor): + self.monitor = monitor + + def set_iter_end_callback(self, callback): + self.iter_end_callback = callback + + def set_iter_start_callback(self, callback): + self.iter_start_callback = callback + + def solve(self, xpu, sym, args, args_grad, auxs, + data_iter, begin_iter, end_iter, args_lrmult=None, debug=False): + if args_lrmult is None: + args_lrmult = dict() + input_desc = data_iter.provide_data + data_iter.provide_label + input_names = [k for k, shape in input_desc] + input_buffs = [mx.nd.empty(shape, ctx=xpu) for k, shape in input_desc] + args = dict(args, **dict(zip(input_names, input_buffs))) + + output_names = sym.list_outputs() + if debug: + sym_group = [] + for x in sym.get_internals(): + if x.name not in args: + if x.name not in output_names: + x = mx.symbol.BlockGrad(x, name=x.name) + sym_group.append(x) + sym = mx.symbol.Group(sym_group) + exe = sym.bind(xpu, args=args, args_grad=args_grad, aux_states=auxs) + + assert len(sym.list_arguments()) == len(exe.grad_arrays) + update_dict = { + name: nd for name, nd in zip(sym.list_arguments(), exe.grad_arrays) if nd is not None + } + batch_size = input_buffs[0].shape[0] + self.optimizer.rescale_grad = 1.0/batch_size + self.optimizer.set_lr_mult(args_lrmult) + + output_dict = {} + output_buff = {} + internal_dict = dict(zip(input_names, input_buffs)) + for key, arr in zip(sym.list_outputs(), exe.outputs): + if key in output_names: + output_dict[key] = arr + output_buff[key] = mx.nd.empty(arr.shape, ctx=mx.cpu()) + else: + internal_dict[key] = arr + + data_iter.reset() + for i in range(begin_iter, end_iter): + if self.iter_start_callback is not None: + if self.iter_start_callback(i): + return + try: + batch = data_iter.next() + except StopIteration: + data_iter.reset() + batch = data_iter.next() + for data, buff in zip(batch.data+batch.label, input_buffs): + data.copyto(buff) + exe.forward(is_train=True) + if self.monitor is not None: + self.monitor.forward_end(i, internal_dict) + for key in output_dict: + output_dict[key].copyto(output_buff[key]) + + exe.backward() + for key, arr in update_dict.items(): + self.updater(key, arr, args[key]) + + if self.metric is not None: + self.metric.update([input_buffs[-1]], + [output_buff[output_names[0]]]) + + if self.monitor is not None: + self.monitor.backward_end(i, args, update_dict, self.metric) + + if self.iter_end_callback is not None: + if self.iter_end_callback(i): + return + exe.outputs[0].wait_to_read() \ No newline at end of file diff --git a/example/warpctc/README.md b/example/warpctc/README.md deleted file mode 100644 index 9ab56b336a5e..000000000000 --- a/example/warpctc/README.md +++ /dev/null @@ -1,108 +0,0 @@ -# Baidu Warp CTC with Mxnet - -Baidu-warpctc is a CTC implement by Baidu which support GPU. CTC can be used with LSTM to solve lable alignment problems in many areas such as OCR, speech recognition. - -## Install baidu warpctc - -``` - cd ~/ - git clone https://github.com/baidu-research/warp-ctc - cd warp-ctc - mkdir build - cd build - cmake .. - make - sudo make install -``` - -## Enable warpctc in mxnet - -``` - comment out following lines in make/config.mk - WARPCTC_PATH = $(HOME)/warp-ctc - MXNET_PLUGINS += plugin/warpctc/warpctc.mk - - rebuild mxnet by - make clean && make -j4 -``` - -## Run examples - -I implement two examples, one is just a toy example which can be used to prove ctc integration is right. The second is a OCR example with LSTM+CTC. You can run it by: - -``` - cd examples/warpctc - python lstm_ocr.py -``` - -Notes: -* Please modify ```contexts = [mx.context.gpu(0)]``` in this file according to your hardware. -* Please review the code ```'./font/Ubuntu-M.ttf'```. Copy your font to here font/yourfont.ttf. To get a free font from [here](http://font.ubuntu.com/). -* The checkpoint will be auto saved in each epoch. And then you can use this checkpoint to do a predict. - -The OCR example is constructed as follows: - -1. I generate 80x30 image for 4 digits captcha by an python captcha library -2. The 80x30 image is used as 80 input for lstm and every input is one column of image (a 30 dim vector) -3. The output layer use CTC loss - -Following code show detail construction of the net: - -``` - def lstm_unroll(num_lstm_layer, seq_len, - num_hidden, num_label): - param_cells = [] - last_states = [] - for i in range(num_lstm_layer): - param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i), - i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i), - h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i), - h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i))) - state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i), - h=mx.sym.Variable("l%d_init_h" % i)) - last_states.append(state) - assert(len(last_states) == num_lstm_layer) - data = mx.sym.Variable('data') - label = mx.sym.Variable('label') - - #every column of image is an input, there are seq_len inputs - wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1) - hidden_all = [] - for seqidx in range(seq_len): - hidden = wordvec[seqidx] - for i in range(num_lstm_layer): - next_state = lstm(num_hidden, indata=hidden, - prev_state=last_states[i], - param=param_cells[i], - seqidx=seqidx, layeridx=i) - hidden = next_state.h - last_states[i] = next_state - hidden_all.append(hidden) - hidden_concat = mx.sym.Concat(*hidden_all, dim=0) - pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11) - - # here we do NOT need to transpose label as other lstm examples do - label = mx.sym.Reshape(data=label, target_shape=(0,)) - #label should be int type, so use cast - label = mx.sym.Cast(data = label, dtype = 'int32') - sm = mx.sym.WarpCTC(data=pred, label=label, label_length = num_label, input_length = seq_len) - return sm -``` - -## Support multi label length - -If you label length is smaller than or equal to b. You should provide labels with length b, and for those samples which label length is smaller than b, you should append 0 to label data to make it have length b. - -Here, 0 is reserved for blank label. - -## Do a predict - -Pelase run: - -``` -python ocr_predict.py -``` - -Notes: -* Change the code following the name of your params and json file. -* You have to do a ```make``` in amalgamation folder.(a libmxnet_predict.so will be created in lib folder.) diff --git a/example/warpctc/infer_ocr.py b/example/warpctc/infer_ocr.py deleted file mode 100644 index d469990ff937..000000000000 --- a/example/warpctc/infer_ocr.py +++ /dev/null @@ -1,118 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# coding=utf-8 -# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme -# pylint: disable=superfluous-parens, no-member, invalid-name -import sys - -sys.path.insert(0, "../../python") -from __future__ import print_function -import numpy as np -import mxnet as mx - -from lstm_model import LSTMInferenceModel - -import cv2, random -from captcha.image import ImageCaptcha - -BATCH_SIZE = 32 -SEQ_LENGTH = 80 - - -def ctc_label(p): - ret = [] - p1 = [0] + p - for i in range(len(p)): - c1 = p1[i] - c2 = p1[i + 1] - if c2 == 0 or c2 == c1: - continue - ret.append(c2) - return ret - - -def remove_blank(l): - ret = [] - for i in range(len(l)): - if l[i] == 0: - break - ret.append(l[i]) - return ret - - -def gen_rand(): - buf = "" - max_len = random.randint(3,4) - for i in range(max_len): - buf += str(random.randint(0,9)) - return buf - -if __name__ == '__main__': - num_hidden = 100 - num_lstm_layer = 2 - - num_epoch = 10 - learning_rate = 0.001 - momentum = 0.9 - num_label = 4 - - n_channel = 1 - contexts = [mx.context.gpu(0)] - _, arg_params, __ = mx.model.load_checkpoint('ocr', num_epoch) - - num = gen_rand() - print('Generated number: ' + num) - # change the fonts accordingly - captcha = ImageCaptcha(fonts=['./data/OpenSans-Regular.ttf']) - img = captcha.generate(num) - img = np.fromstring(img.getvalue(), dtype='uint8') - img = cv2.imdecode(img, cv2.IMREAD_GRAYSCALE) - img = cv2.resize(img, (80, 30)) - - img = img.transpose(1, 0) - - img = img.reshape((1, 80 * 30)) - img = np.multiply(img, 1 / 255.0) - - data_shape = [('data', (1, n_channel * 80 * 30))] - input_shapes = dict(data_shape) - - model = LSTMInferenceModel(num_lstm_layer, - SEQ_LENGTH, - num_hidden=num_hidden, - num_label=num_label, - arg_params=arg_params, - data_size = n_channel * 30 * 80, - ctx=contexts[0]) - - prob = model.forward(mx.nd.array(img)) - - p = [] - for k in range(SEQ_LENGTH): - p.append(np.argmax(prob[k])) - - p = ctc_label(p) - print('Predicted label: ' + str(p)) - - pred = '' - for c in p: - pred += str((int(c) - 1)) - - print('Predicted number: ' + pred) - - diff --git a/example/warpctc/lstm.py b/example/warpctc/lstm.py deleted file mode 100644 index 9e0e05c9011d..000000000000 --- a/example/warpctc/lstm.py +++ /dev/null @@ -1,135 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# pylint:skip-file -import sys -sys.path.insert(0, "../../python") -import mxnet as mx -import numpy as np -from collections import namedtuple -import time -import math -LSTMState = namedtuple("LSTMState", ["c", "h"]) -LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias", - "h2h_weight", "h2h_bias"]) -LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol", - "init_states", "last_states", - "seq_data", "seq_labels", "seq_outputs", - "param_blocks"]) - -def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx): - """LSTM Cell symbol""" - i2h = mx.sym.FullyConnected(data=indata, - weight=param.i2h_weight, - bias=param.i2h_bias, - num_hidden=num_hidden * 4, - name="t%d_l%d_i2h" % (seqidx, layeridx)) - h2h = mx.sym.FullyConnected(data=prev_state.h, - weight=param.h2h_weight, - bias=param.h2h_bias, - num_hidden=num_hidden * 4, - name="t%d_l%d_h2h" % (seqidx, layeridx)) - gates = i2h + h2h - slice_gates = mx.sym.SliceChannel(gates, num_outputs=4, - name="t%d_l%d_slice" % (seqidx, layeridx)) - in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid") - in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh") - forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid") - out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid") - next_c = (forget_gate * prev_state.c) + (in_gate * in_transform) - next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh") - return LSTMState(c=next_c, h=next_h) - - -def lstm_unroll(num_lstm_layer, seq_len, - num_hidden, num_label): - param_cells = [] - last_states = [] - for i in range(num_lstm_layer): - param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i), - i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i), - h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i), - h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i))) - state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i), - h=mx.sym.Variable("l%d_init_h" % i)) - last_states.append(state) - assert(len(last_states) == num_lstm_layer) - - # embeding layer - data = mx.sym.Variable('data') - label = mx.sym.Variable('label') - wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1) - - hidden_all = [] - for seqidx in range(seq_len): - hidden = wordvec[seqidx] - for i in range(num_lstm_layer): - next_state = lstm(num_hidden, indata=hidden, - prev_state=last_states[i], - param=param_cells[i], - seqidx=seqidx, layeridx=i) - hidden = next_state.h - last_states[i] = next_state - hidden_all.append(hidden) - - hidden_concat = mx.sym.Concat(*hidden_all, dim=0) - pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11) - - label = mx.sym.Reshape(data=label, shape=(-1,)) - label = mx.sym.Cast(data = label, dtype = 'int32') - sm = mx.sym.WarpCTC(data=pred, label=label, label_length = num_label, input_length = seq_len) - return sm - - -def lstm_inference_symbol(num_lstm_layer, seq_len, num_hidden, num_label): - param_cells = [] - last_states = [] - for i in range(num_lstm_layer): - param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i), - i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i), - h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i), - h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i))) - state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i), - h=mx.sym.Variable("l%d_init_h" % i)) - last_states.append(state) - assert (len(last_states) == num_lstm_layer) - - # embeding layer - data = mx.sym.Variable('data') - wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1) - - hidden_all = [] - for seqidx in range(seq_len): - hidden = wordvec[seqidx] - for i in range(num_lstm_layer): - next_state = lstm(num_hidden, indata=hidden, - prev_state=last_states[i], - param=param_cells[i], - seqidx=seqidx, layeridx=i) - hidden = next_state.h - last_states[i] = next_state - hidden_all.append(hidden) - - hidden_concat = mx.sym.Concat(*hidden_all, dim=0) - fc = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11) - sm = mx.sym.SoftmaxOutput(data=fc, name='softmax') - - output = [sm] - for state in last_states: - output.append(state.c) - output.append(state.h) - return mx.sym.Group(output) diff --git a/example/warpctc/lstm_model.py b/example/warpctc/lstm_model.py deleted file mode 100644 index d359f1ae5a90..000000000000 --- a/example/warpctc/lstm_model.py +++ /dev/null @@ -1,71 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme -# pylint: disable=superfluous-parens, no-member, invalid-name -import sys -sys.path.insert(0, "../../python") -import numpy as np -import mxnet as mx - -from lstm import LSTMState, LSTMParam, lstm, lstm_inference_symbol - - -class LSTMInferenceModel(object): - def __init__(self, - num_lstm_layer, - seq_len, - num_hidden, - num_label, - arg_params, - data_size, - ctx=mx.cpu()): - self.sym = lstm_inference_symbol(num_lstm_layer, - seq_len, - num_hidden, - num_label) - - batch_size = 1 - init_c = [('l%d_init_c'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)] - init_h = [('l%d_init_h'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)] - data_shape = [("data", (batch_size, data_size))] - input_shapes = dict(init_c + init_h + data_shape) - self.executor = self.sym.simple_bind(ctx=ctx, **input_shapes) - - for key in self.executor.arg_dict.keys(): - if key in arg_params: - arg_params[key].copyto(self.executor.arg_dict[key]) - - state_name = [] - for i in range(num_lstm_layer): - state_name.append("l%d_init_c" % i) - state_name.append("l%d_init_h" % i) - - self.states_dict = dict(zip(state_name, self.executor.outputs[1:])) - self.input_arr = mx.nd.zeros(data_shape[0][1]) - - def forward(self, input_data, new_seq=False): - if new_seq == True: - for key in self.states_dict.keys(): - self.executor.arg_dict[key][:] = 0. - input_data.copyto(self.executor.arg_dict["data"]) - self.executor.forward() - for key in self.states_dict.keys(): - self.states_dict[key].copyto(self.executor.arg_dict[key]) - prob = self.executor.outputs[0].asnumpy() - return prob diff --git a/example/warpctc/lstm_ocr.py b/example/warpctc/lstm_ocr.py deleted file mode 100644 index 9dd39efb4962..000000000000 --- a/example/warpctc/lstm_ocr.py +++ /dev/null @@ -1,229 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme -# pylint: disable=superfluous-parens, no-member, invalid-name -from __future__ import print_function -import sys, random -sys.path.insert(0, "../../python") -import numpy as np -import mxnet as mx - -from lstm import lstm_unroll - -from io import BytesIO -from captcha.image import ImageCaptcha -import cv2, random - -class SimpleBatch(object): - def __init__(self, data_names, data, label_names, label): - self.data = data - self.label = label - self.data_names = data_names - self.label_names = label_names - - self.pad = 0 - self.index = None # TODO: what is index? - - @property - def provide_data(self): - return [(n, x.shape) for n, x in zip(self.data_names, self.data)] - - @property - def provide_label(self): - return [(n, x.shape) for n, x in zip(self.label_names, self.label)] - -def gen_rand(): - buf = "" - max_len = random.randint(3,4) - for i in range(max_len): - buf += str(random.randint(0,9)) - return buf - -def get_label(buf): - ret = np.zeros(4) - for i in range(len(buf)): - ret[i] = 1 + int(buf[i]) - if len(buf) == 3: - ret[3] = 0 - return ret - -class OCRIter(mx.io.DataIter): - def __init__(self, count, batch_size, num_label, init_states): - super(OCRIter, self).__init__() - # you can get this font from http://font.ubuntu.com/ - self.captcha = ImageCaptcha(fonts=['./font/Ubuntu-M.ttf']) - self.batch_size = batch_size - self.count = count - self.num_label = num_label - self.init_states = init_states - self.init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states] - self.provide_data = [('data', (batch_size, 80, 30))] + init_states - self.provide_label = [('label', (self.batch_size, 4))] - - def __iter__(self): - print('iter') - init_state_names = [x[0] for x in self.init_states] - for k in range(self.count): - data = [] - label = [] - for i in range(self.batch_size): - num = gen_rand() - img = self.captcha.generate(num) - img = np.fromstring(img.getvalue(), dtype='uint8') - img = cv2.imdecode(img, cv2.IMREAD_GRAYSCALE) - img = cv2.resize(img, (80, 30)) - img = img.transpose(1, 0) - img = img.reshape((80, 30)) - img = np.multiply(img, 1/255.0) - data.append(img) - label.append(get_label(num)) - - data_all = [mx.nd.array(data)] + self.init_state_arrays - label_all = [mx.nd.array(label)] - data_names = ['data'] + init_state_names - label_names = ['label'] - - - data_batch = SimpleBatch(data_names, data_all, label_names, label_all) - yield data_batch - - def reset(self): - pass - -BATCH_SIZE = 32 -SEQ_LENGTH = 80 - -def ctc_label(p): - ret = [] - p1 = [0] + p - for i in range(len(p)): - c1 = p1[i] - c2 = p1[i+1] - if c2 == 0 or c2 == c1: - continue - ret.append(c2) - return ret - -def remove_blank(l): - ret = [] - for i in range(len(l)): - if l[i] == 0: - break - ret.append(l[i]) - return ret - -def Accuracy(label, pred): - global BATCH_SIZE - global SEQ_LENGTH - hit = 0. - total = 0. - for i in range(BATCH_SIZE): - l = remove_blank(label[i]) - p = [] - for k in range(SEQ_LENGTH): - p.append(np.argmax(pred[k * BATCH_SIZE + i])) - p = ctc_label(p) - if len(p) == len(l): - match = True - for k in range(len(p)): - if p[k] != int(l[k]): - match = False - break - if match: - hit += 1.0 - total += 1.0 - return hit / total - -def LCS(p,l): - # Dynamic Programming Finding LCS - if len(p) == 0: - return 0 - P = np.array(list(p)).reshape((1, len(p))) - L = np.array(list(l)).reshape((len(l), 1)) - M = np.int32(P == L) - for i in range(M.shape[0]): - for j in range(M.shape[1]): - up = 0 if i == 0 else M[i-1,j] - left = 0 if j == 0 else M[i,j-1] - M[i,j] = max(up, left, M[i,j] if (i == 0 or j == 0) else M[i,j] + M[i-1,j-1]) - return M.max() - - -def Accuracy_LCS(label, pred): - global BATCH_SIZE - global SEQ_LENGTH - hit = 0. - total = 0. - for i in range(BATCH_SIZE): - l = remove_blank(label[i]) - p = [] - for k in range(SEQ_LENGTH): - p.append(np.argmax(pred[k * BATCH_SIZE + i])) - p = ctc_label(p) - hit += LCS(p,l) * 1.0 / len(l) - total += 1.0 - return hit / total - -if __name__ == '__main__': - num_hidden = 100 - num_lstm_layer = 2 - - num_epoch = 10 - learning_rate = 0.001 - momentum = 0.9 - num_label = 4 - - contexts = [mx.context.gpu(0)] - - def sym_gen(seq_len): - return lstm_unroll(num_lstm_layer, seq_len, - num_hidden=num_hidden, - num_label = num_label) - - init_c = [('l%d_init_c'%l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer)] - init_h = [('l%d_init_h'%l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer)] - init_states = init_c + init_h - - data_train = OCRIter(10000, BATCH_SIZE, num_label, init_states) - data_val = OCRIter(1000, BATCH_SIZE, num_label, init_states) - - symbol = sym_gen(SEQ_LENGTH) - - model = mx.model.FeedForward(ctx=contexts, - symbol=symbol, - num_epoch=num_epoch, - learning_rate=learning_rate, - momentum=momentum, - wd=0.00001, - initializer=mx.init.Xavier(factor_type="in", magnitude=2.34)) - - import logging - head = '%(asctime)-15s %(message)s' - logging.basicConfig(level=logging.DEBUG, format=head) - - print('begin fit') - - prefix = 'ocr' - model.fit(X=data_train, eval_data=data_val, - eval_metric = mx.metric.np(Accuracy), - # Use the following eval_metric if your num_label >= 10, or varies in a wide range - # eval_metric = mx.metric.np(Accuracy_LCS), - batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 50), - epoch_end_callback = mx.callback.do_checkpoint(prefix, 1)) - - model.save(prefix) diff --git a/example/warpctc/ocr_predict.py b/example/warpctc/ocr_predict.py deleted file mode 100644 index 3096a664a20f..000000000000 --- a/example/warpctc/ocr_predict.py +++ /dev/null @@ -1,106 +0,0 @@ -#!/usr/bin/env python2.7 - -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# coding=utf-8 -from __future__ import print_function -import sys, os -curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) -sys.path.append("../../amalgamation/python/") -sys.path.append("../../python/") - -from mxnet_predict import Predictor -import mxnet as mx - -import numpy as np -import cv2 -import os - -class lstm_ocr_model(object): - # Keep Zero index for blank. (CTC request it) - CONST_CHAR='0123456789' - def __init__(self, path_of_json, path_of_params): - super(lstm_ocr_model, self).__init__() - self.path_of_json = path_of_json - self.path_of_params = path_of_params - self.predictor = None - self.__init_ocr() - - def __init_ocr(self): - num_label = 4 # Set your max length of label, add one more for blank - batch_size = 1 - - num_hidden = 100 - num_lstm_layer = 2 - init_c = [('l%d_init_c'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)] - init_h = [('l%d_init_h'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)] - init_states = init_c + init_h - - init_state_arrays = np.zeros((batch_size, num_hidden), dtype="float32") - self.init_state_dict={} - for x in init_states: - self.init_state_dict[x[0]] = init_state_arrays - - all_shapes = [('data', (batch_size, 80 * 30))] + init_states + [('label', (batch_size, num_label))] - all_shapes_dict = {} - for _shape in all_shapes: - all_shapes_dict[_shape[0]] = _shape[1] - self.predictor = Predictor(open(self.path_of_json).read(), - open(self.path_of_params).read(), - all_shapes_dict) - - def forward_ocr(self, img): - img = cv2.resize(img, (80, 30)) - img = img.transpose(1, 0) - img = img.reshape((80 * 30)) - img = np.multiply(img, 1/255.0) - self.predictor.forward(data=img, **self.init_state_dict) - prob = self.predictor.get_output(0) - label_list = [] - for p in prob: - max_index = np.argsort(p)[::-1][0] - label_list.append(max_index) - return self.__get_string(label_list) - - def __get_string(self, label_list): - # Do CTC label rule - # CTC cannot emit a repeated symbol on consecutive timesteps - ret = [] - label_list2 = [0] + list(label_list) - for i in range(len(label_list)): - c1 = label_list2[i] - c2 = label_list2[i+1] - if c2 == 0 or c2 == c1: - continue - ret.append(c2) - # change to ascii - s = '' - for l in ret: - if l > 0 and l < (len(lstm_ocr_model.CONST_CHAR)+1): - c = lstm_ocr_model.CONST_CHAR[l-1] - else: - c = '' - s += c - return s - -if __name__ == '__main__': - _lstm_ocr_model = lstm_ocr_model('ocr-symbol.json', 'ocr-0010.params') - img = cv2.imread('sample.jpg', 0) - _str = _lstm_ocr_model.forward_ocr(img) - print('Result: ', _str) - diff --git a/example/warpctc/toy_ctc.py b/example/warpctc/toy_ctc.py deleted file mode 100644 index c7b0ccc3df3d..000000000000 --- a/example/warpctc/toy_ctc.py +++ /dev/null @@ -1,181 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme -# pylint: disable=superfluous-parens, no-member, invalid-name -from __future__ import print_function -import sys -sys.path.insert(0, "../../python") -import numpy as np -import mxnet as mx -import random -from lstm import lstm_unroll - -class SimpleBatch(object): - def __init__(self, data_names, data, label_names, label): - self.data = data - self.label = label - self.data_names = data_names - self.label_names = label_names - - self.pad = 0 - self.index = None # TODO: what is index? - - @property - def provide_data(self): - return [(n, x.shape) for n, x in zip(self.data_names, self.data)] - - @property - def provide_label(self): - return [(n, x.shape) for n, x in zip(self.label_names, self.label)] - -def gen_feature(n): - ret = np.zeros(10) - ret[n] = 1 - return ret - -def gen_rand(): - num = random.randint(0, 9999) - buf = str(num) - while len(buf) < 4: - buf = "0" + buf - ret = [] - for i in range(80): - c = int(buf[i // 20]) - ret.append(gen_feature(c)) - return buf, ret - -def get_label(buf): - ret = np.zeros(4) - for i in range(4): - ret[i] = 1 + int(buf[i]) - return ret - -class DataIter(mx.io.DataIter): - def __init__(self, count, batch_size, num_label, init_states): - super(DataIter, self).__init__() - self.batch_size = batch_size - self.count = count - self.num_label = num_label - self.init_states = init_states - self.init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states] - self.provide_data = [('data', (batch_size, 80, 10))] + init_states - self.provide_label = [('label', (self.batch_size, 4))] - - def __iter__(self): - init_state_names = [x[0] for x in self.init_states] - for k in range(self.count): - data = [] - label = [] - for i in range(self.batch_size): - num, img = gen_rand() - data.append(img) - label.append(get_label(num)) - - data_all = [mx.nd.array(data)] + self.init_state_arrays - label_all = [mx.nd.array(label)] - data_names = ['data'] + init_state_names - label_names = ['label'] - - - data_batch = SimpleBatch(data_names, data_all, label_names, label_all) - yield data_batch - - def reset(self): - pass - -BATCH_SIZE = 32 -SEQ_LENGTH = 80 - -def ctc_label(p): - ret = [] - p1 = [0] + p - for i in range(len(p)): - c1 = p1[i] - c2 = p1[i+1] - if c2 == 0 or c2 == c1: - continue - ret.append(c2) - return ret - - -def Accuracy(label, pred): - global BATCH_SIZE - global SEQ_LENGTH - hit = 0. - total = 0. - for i in range(BATCH_SIZE): - l = label[i] - p = [] - for k in range(SEQ_LENGTH): - p.append(np.argmax(pred[k * BATCH_SIZE + i])) - p = ctc_label(p) - if len(p) == len(l): - match = True - for k in range(len(p)): - if p[k] != int(l[k]): - match = False - break - if match: - hit += 1.0 - total += 1.0 - return hit / total - -if __name__ == '__main__': - num_hidden = 100 - num_lstm_layer = 1 - - num_epoch = 10 - learning_rate = 0.001 - momentum = 0.9 - num_label = 4 - - contexts = [mx.context.gpu(0)] - - def sym_gen(seq_len): - return lstm_unroll(num_lstm_layer, seq_len, - num_hidden=num_hidden, - num_label = num_label) - - init_c = [('l%d_init_c'%l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer)] - init_h = [('l%d_init_h'%l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer)] - init_states = init_c + init_h - - data_train = DataIter(100000, BATCH_SIZE, num_label, init_states) - data_val = DataIter(1000, BATCH_SIZE, num_label, init_states) - - symbol = sym_gen(SEQ_LENGTH) - - model = mx.model.FeedForward(ctx=contexts, - symbol=symbol, - num_epoch=num_epoch, - learning_rate=learning_rate, - momentum=momentum, - wd=0.00001, - initializer=mx.init.Xavier(factor_type="in", magnitude=2.34)) - - import logging - head = '%(asctime)-15s %(message)s' - logging.basicConfig(level=logging.DEBUG, format=head) - - print('begin fit') - - model.fit(X=data_train, eval_data=data_val, - eval_metric = mx.metric.np(Accuracy), - batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 50),) - - model.save("ocr") diff --git a/mshadow b/mshadow index 3d87ed2a4b47..16ac8cdfd1c5 160000 --- a/mshadow +++ b/mshadow @@ -1 +1 @@ -Subproject commit 3d87ed2a4b47ef749c616f208cee45d920fb6e6e +Subproject commit 16ac8cdfd1c5fbdbee781ec29aaa4478e6eb0ae0 diff --git a/prepare_mkl.sh b/prepare_mkl.sh index a9d659f71591..657722eb938c 100755 --- a/prepare_mkl.sh +++ b/prepare_mkl.sh @@ -84,7 +84,7 @@ elif [ $PLATFORM == "Linux" ]; then fi ARCHIVE_BASENAME=mklml_${INFIX}_2018.0.1.${VERSION_MATCH}.tgz MKL_CONTENT_DIR=`echo $ARCHIVE_BASENAME | rev | cut -d "." -f 2- | rev` -MKLURL="https://github.com/01org/mkl-dnn/releases/download/v0.11/$ARCHIVE_BASENAME" +MKLURL="https://github.com/01org/mkl-dnn/releases/download/v0.12/$ARCHIVE_BASENAME" # there are diffrent MKL lib to be used for GCC and for ICC reg='^[0-9]+$' VERSION_LINE=`GetVersionName $MKLROOT` diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 0d49def8cf65..fd75e4b7b7aa 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -22,6 +22,7 @@ import copy import warnings +import re from .. import symbol, ndarray, initializer from ..symbol import Symbol @@ -227,13 +228,38 @@ def params(self): children's parameters).""" return self._params - def collect_params(self): + def collect_params(self, select=None): """Returns a :py:class:`ParameterDict` containing this :py:class:`Block` and all of its - children's Parameters.""" + children's Parameters(default), also can returns the select :py:class:`ParameterDict` + which match some given regular expressions. + + For example, collect the specified parameter in ['conv1_weight', 'conv1_bias', 'fc_weight', + 'fc_bias']:: + + model.collect_params('conv1_weight|conv1_bias|fc_weight|fc_bias') + + or collect all paramters which their name ends with 'weight' or 'bias', this can be done + using regular expressions:: + + model.collect_params('.*weight|.*bias') + + Parameters + ---------- + select : str + regular expressions + + Returns + ------- + The selected :py:class:`ParameterDict` + """ ret = ParameterDict(self._params.prefix) - ret.update(self.params) + if not select: + ret.update(self.params) + else: + pattern = re.compile(select) + ret.update({name:value for name, value in self.params.items() if pattern.match(name)}) for cld in self._children: - ret.update(cld.collect_params()) + ret.update(cld.collect_params(select=select)) return ret def save_params(self, filename): @@ -261,7 +287,6 @@ def load_params(self, filename, ctx, allow_missing=False, self.collect_params().load(filename, ctx, allow_missing, ignore_extra, self.prefix) - def register_child(self, block): """Registers block as a child of self. :py:class:`Block` s assigned to self as attributes will be registered automatically.""" diff --git a/python/mxnet/operator.py b/python/mxnet/operator.py index 141a33806ad8..9bff11765908 100644 --- a/python/mxnet/operator.py +++ b/python/mxnet/operator.py @@ -21,6 +21,7 @@ from __future__ import absolute_import import traceback +import warnings from array import array from threading import Lock @@ -47,6 +48,7 @@ class PythonOp(object): def __init__(self, need_top_grad=True): self.info_ = None self.need_top_grad_ = need_top_grad + warnings.warn('PythonOp has been deprecated. Please use CustomOp') def __call__(self, *args, **kwargs): return self.get_symbol(*args, **kwargs) @@ -152,6 +154,7 @@ class NumpyOp(PythonOp): """ def __init__(self, need_top_grad=True): super(NumpyOp, self).__init__(need_top_grad) + warnings.warn('NumpyOp has been deprecated. Please use CustomOp') def get_symbol(self, *args, **kwargs): fb_functype = CFUNCTYPE(None, c_int, POINTER(POINTER(mx_float)), POINTER(c_int), @@ -254,6 +257,7 @@ class NDArrayOp(PythonOp): """ def __init__(self, need_top_grad=True): super(NDArrayOp, self).__init__(need_top_grad) + warnings.warn('NDArrayOp has been deprecated. Please use CustomOp') def get_symbol(self, *args, **kwargs): fb_functype = CFUNCTYPE(c_bool, c_int, POINTER(c_void_p), POINTER(c_int), c_void_p) diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py index feff87e0baab..4285aecef1f5 100644 --- a/python/mxnet/optimizer.py +++ b/python/mxnet/optimizer.py @@ -25,7 +25,8 @@ from .base import py_str from .ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs) from .ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update, - mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update) + mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update, + signsgd_update, signum_update) from .ndarray import _internal from .ndarray import op from .ndarray import sparse @@ -534,6 +535,67 @@ def update_multi_precision(self, index, weight, grad, state): self._update_impl(index, weight, grad, state, multi_precision=use_multi_precision) +@register +class Signum(Optimizer): + """The Signum optimizer that takes the sign of gradient or momentum. + + The optimizer updates the weight by: + + rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight + state = momentum * state + (1-momentum)*rescaled_grad + weight = (1 - lr * wd_lh) * weight - lr * sign(state) + + See the original paper at: https://jeremybernste.in/projects/amazon/signum.pdf + + For details of the update algorithm see + :class:`~mxnet.ndarray.signsgd_update` and :class:`~mxnet.ndarray.signum_update`. + + This optimizer accepts the following parameters in addition to those accepted + by :class:`.Optimizer`. + + Parameters + ---------- + momentum : float, optional + The momentum value. + wd_lh : float, optional + The amount of decoupled weight decay regularization, see details in the original paper at:\ + https://arxiv.org/abs/1711.05101 + """ + def __init__(self, learning_rate=0.01, momentum=0.9, wd_lh=0.0, **kwargs): + super(Signum, self).__init__(learning_rate=learning_rate, **kwargs) + self.momentum = momentum + self.wd_lh = wd_lh + + def create_state(self, index, weight): + momentum = None + if self.momentum != 0.0: + momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype) + return momentum + + def _update_impl(self, index, weight, grad, state): + assert(isinstance(weight, NDArray)) + assert(isinstance(grad, NDArray)) + self._update_count(index) + lr = self._get_lr(index) + wd = self._get_wd(index) + + kwargs = {'rescale_grad': self.rescale_grad} + if self.momentum > 0: + kwargs['momentum'] = self.momentum + if self.clip_gradient: + kwargs['clip_gradient'] = self.clip_gradient + if self.wd_lh: + kwargs['wd_lh'] = self.wd_lh + + if state is not None: + signum_update(weight, grad, state, out=weight, + lr=lr, wd=wd, **kwargs) + else: + signsgd_update(weight, grad, out=weight, + lr=lr, wd=wd, **kwargs) + + def update(self, index, weight, grad, state): + self._update_impl(index, weight, grad, state) @register class FTML(Optimizer): @@ -702,8 +764,7 @@ def update(self, index, weight, grad, state): if self.clip_gradient is not None: grad = clip(grad, -self.clip_gradient, self.clip_gradient) weight[:] += - lr/2 * (grad + wd * weight) + normal(0, math.sqrt(lr), - shape=weight.shape, - ctx=weight.context) + weight.shape, weight.context) @register # pylint: disable=invalid-name diff --git a/src/common/random_generator.cu b/src/common/random_generator.cu index edb3d4e26e93..66969fe89c49 100644 --- a/src/common/random_generator.cu +++ b/src/common/random_generator.cu @@ -45,13 +45,13 @@ __global__ void rand_generator_seed_kernel(curandStatePhilox4_32_10_t *states_, } template<> -void RandGenerator::Seed(Stream *s, uint32_t seed) { +void RandGenerator::Seed(mshadow::Stream *s, uint32_t seed) { using namespace mshadow::cuda; int ngrid = std::min(kMaxGridNum, (RandGenerator::kNumRandomStates + kBaseThreadNum - 1) / kBaseThreadNum); rand_generator_seed_kernel - <<::GetStream(s)>>>( + <<::GetStream(s)>>>( states_, RandGenerator::kNumRandomStates, seed); diff --git a/src/common/random_generator.h b/src/common/random_generator.h index 01cbbd166f25..5d78b616e534 100644 --- a/src/common/random_generator.h +++ b/src/common/random_generator.h @@ -34,8 +34,6 @@ #include "../common/cuda_utils.h" #endif // MXNET_USE_CUDA -using namespace mshadow; - namespace mxnet { namespace common { namespace random { @@ -90,7 +88,7 @@ class RandGenerator { delete[] inst->states_; } - MSHADOW_XINLINE void Seed(Stream *, uint32_t seed) { + MSHADOW_XINLINE void Seed(mshadow::Stream *, uint32_t seed) { for (int i = 0; i < kNumRandomStates; ++i) (states_ + i)->seed(seed + i); } @@ -161,7 +159,7 @@ class RandGenerator { CUDA_CALL(cudaFree(inst->states_)); } - void Seed(Stream *s, uint32_t seed); + void Seed(mshadow::Stream *s, uint32_t seed); private: curandStatePhilox4_32_10_t *states_; diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index 33b7dd5fe5a8..c2564db0f079 100644 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -66,6 +66,7 @@ struct SGDParam : public dmlc::Parameter { } }; + struct SGDKernel { template MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* weight_data, @@ -228,6 +229,7 @@ struct SGDMomParam : public dmlc::Parameter { } }; + struct SGDMomKernel { template MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data, const DType* weight_data, @@ -1281,6 +1283,146 @@ inline void FtrlUpdateEx(const nnvm::NodeAttrs& attrs, } } + +// Implementation for signSGD and Signum + +struct SignSGDParam : public dmlc::Parameter { + float lr; + float wd; + float rescale_grad; + float clip_gradient; + DMLC_DECLARE_PARAMETER(SignSGDParam) { + DMLC_DECLARE_FIELD(lr) + .describe("Learning rate"); + DMLC_DECLARE_FIELD(wd) + .set_default(0.0f) + .describe("Weight decay augments the objective function with a " + "regularization term that penalizes large weights. " + "The penalty scales with the square of the magnitude of each weight."); + DMLC_DECLARE_FIELD(rescale_grad) + .set_default(1.0f) + .describe("Rescale gradient to grad = rescale_grad*grad."); + DMLC_DECLARE_FIELD(clip_gradient) + .set_default(-1.0f) + .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " + "If clip_gradient <= 0, gradient clipping is turned off. " + "grad = max(min(grad, clip_gradient), -clip_gradient)."); + } +}; + + +struct SignSGDKernel { + template + MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* weight_data, + const DType* grad_data, const DType param_clip_gradient, + const DType param_lr, const DType param_wd, const DType param_rescale_grad, + const OpReqType req) { + + // param_clip_gradient has no effect for SignSGD + KERNEL_ASSIGN(out_data[i], req, + (1.f-param_lr*param_wd)*weight_data[i] + - (param_lr)*((grad_data[i] > 0) - (grad_data[i] < 0))); + } +}; + +template +inline void SignSGDUpdate(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + const SignSGDParam& param = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + Tensor weight = inputs[0].FlatTo2D(s); + Tensor grad = inputs[1].FlatTo2D(s); + Tensor out = outputs[0].FlatTo2D(s); + Kernel::Launch(s, weight.shape_.Size(), out.dptr_, weight.dptr_, + grad.dptr_, static_cast(param.clip_gradient), + static_cast(param.lr), static_cast(param.wd), + static_cast(param.rescale_grad), req[0]); + }); +} + + +struct SignumParam : public dmlc::Parameter { + float lr; + float momentum; + float wd; + float rescale_grad; + float clip_gradient; + float wd_lh; // the amount of algorithmic weight decay by Loshchilov and Frank Hutter + DMLC_DECLARE_PARAMETER(SignumParam) { + DMLC_DECLARE_FIELD(lr) + .describe("Learning rate"); + DMLC_DECLARE_FIELD(momentum) + .set_default(0.0f) + .describe("The decay rate of momentum estimates at each epoch."); + DMLC_DECLARE_FIELD(wd) + .set_default(0.0f) + .describe("Weight decay augments the objective function with a " + "regularization term that penalizes large weights. " + "The penalty scales with the square of the magnitude of each weight."); + DMLC_DECLARE_FIELD(rescale_grad) + .set_default(1.0f) + .describe("Rescale gradient to grad = rescale_grad*grad."); + DMLC_DECLARE_FIELD(clip_gradient) + .set_default(-1.0f) + .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " + "If clip_gradient <= 0, gradient clipping is turned off. " + "grad = max(min(grad, clip_gradient), -clip_gradient)."); + DMLC_DECLARE_FIELD(wd_lh) + .set_default(0.0f) + .describe("The amount of weight decay that does not go into gradient/momentum calculations" + "otherwise do weight decay algorithmically only."); + } +}; + +struct SignumKernel { + template + MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data, const DType* weight_data, + const DType* grad_data, const DType param_clip_gradient, const DType param_momentum, + const DType param_lr, const DType param_wd, const DType param_rescale_grad, + const DType param_wd_lh, const OpReqType req) { + if (param_clip_gradient >= 0.0f) { + mom_data[i] = param_momentum*mom_data[i] + - (1-param_momentum)*param_wd*weight_data[i] + - (1-param_momentum) + *mshadow_op::clip::Map(param_rescale_grad*grad_data[i], param_clip_gradient); + } else { + mom_data[i] = param_momentum*mom_data[i] + - (1-param_momentum)*param_wd*weight_data[i] + - (1-param_momentum)*param_rescale_grad*grad_data[i]; + } + KERNEL_ASSIGN(out_data[i], req, (1.f-param_lr*param_wd_lh)*weight_data[i] + + (param_lr)*((mom_data[i] > 0) - (mom_data[i] < 0))); + } +}; + +template +inline void SignumUpdate(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + SignumParam param = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + Tensor weight = inputs[0].FlatTo2D(s); + Tensor grad = inputs[1].FlatTo2D(s); + Tensor mom = inputs[2].FlatTo2D(s); + Tensor out = outputs[0].FlatTo2D(s); + Kernel::Launch(s, weight.shape_.Size(), out.dptr_, mom.dptr_, weight.dptr_, + grad.dptr_, static_cast(param.clip_gradient), static_cast(param.momentum), + static_cast(param.lr), static_cast(param.wd), + static_cast(param.rescale_grad), static_cast(param.wd_lh), req[0]); + }); +} + + + } // namespace op } // namespace mxnet diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc index dda809255dce..8760fe94a526 100644 --- a/src/operator/optimizer_op.cc +++ b/src/operator/optimizer_op.cc @@ -36,6 +36,67 @@ DMLC_REGISTER_PARAMETER(AdamParam); DMLC_REGISTER_PARAMETER(RMSPropParam); DMLC_REGISTER_PARAMETER(RMSPropAlexParam); DMLC_REGISTER_PARAMETER(FtrlParam); +DMLC_REGISTER_PARAMETER(SignSGDParam); +DMLC_REGISTER_PARAMETER(SignumParam); + +NNVM_REGISTER_OP(signsgd_update) +.describe(R"code(Update function for SignSGD optimizer. +.. math:: + + g_t = \nabla J(W_{t-1})\\ + W_t = W_{t-1} - \eta_t \text{sign}(g_t)} + +It updates the weights using:: + + weight = weight - learning_rate * sign(gradient) + +.. note:: + - sparse ndarray not supported for this optimizer yet. +)code" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", ElemwiseShape<2, 1>) +.set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FCompute", SignSGDUpdate) +.add_argument("weight", "NDArray-or-Symbol", "Weight") +.add_argument("grad", "NDArray-or-Symbol", "Gradient") +.add_arguments(SignSGDParam::__FIELDS__()); + + +NNVM_REGISTER_OP(signum_update) +.describe(R"code(SIGN momentUM (Signum) optimizer. + +.. math:: + + g_t = \nabla J(W_{t-1})\\ + m_t = \beta m_{t-1} + (1 - \beta) g_t\\ + W_t = W_{t-1} - \eta_t \text{sign}(m_t)} + +It updates the weights using:: + state = momentum * state + (1-momentum) * gradient + weight = weight - learning_rate * sign(state) + +Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch. + +.. note:: + - sparse ndarray not supported for this optimizer yet. +)code" ADD_FILELINE) +.set_num_inputs(3) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", ElemwiseShape<3, 1>) +.set_attr("FInferType", ElemwiseType<3, 1>) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{2}; + }) +.set_attr("FCompute", SignumUpdate) +.add_argument("weight", "NDArray-or-Symbol", "Weight") +.add_argument("grad", "NDArray-or-Symbol", "Gradient") +.add_argument("mom", "NDArray-or-Symbol", "Momentum") +.add_arguments(SignumParam::__FIELDS__()); + template<> void SGDMomStdUpdateDnsRspDnsImpl(const SGDMomParam& param, diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu index 9512e92a80ec..891f24fe7935 100644 --- a/src/operator/optimizer_op.cu +++ b/src/operator/optimizer_op.cu @@ -94,6 +94,13 @@ void SGDMomStdUpdateDnsRspDnsImpl(const SGDMomParam& param, }); } + +NNVM_REGISTER_OP(signsgd_update) +.set_attr("FCompute", SignSGDUpdate); + +NNVM_REGISTER_OP(signum_update) +.set_attr("FCompute", SignumUpdate); + NNVM_REGISTER_OP(sgd_update) .set_attr("FCompute", SGDUpdate) .set_attr("FComputeEx", SGDUpdateEx); diff --git a/tests/ci_build/Dockerfile.build_cuda b/tests/ci_build/Dockerfile.build_cuda index 9084cca82b6f..d659a1a8adea 100644 --- a/tests/ci_build/Dockerfile.build_cuda +++ b/tests/ci_build/Dockerfile.build_cuda @@ -20,7 +20,7 @@ COPY install/ubuntu_install_nvidia.sh /install/ RUN /install/ubuntu_install_nvidia.sh # Add MKLML libraries -RUN wget --no-check-certificate -O /tmp/mklml.tgz https://github.com/01org/mkl-dnn/releases/download/v0.11/mklml_lnx_2018.0.1.20171227.tgz +RUN wget --no-check-certificate -O /tmp/mklml.tgz https://github.com/01org/mkl-dnn/releases/download/v0.12/mklml_lnx_2018.0.1.20171227.tgz RUN tar -zxvf /tmp/mklml.tgz && cp -rf mklml_*/* /usr/local/ && rm -rf mklml_* ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib diff --git a/tests/ci_build/Dockerfile.cpu_mklml b/tests/ci_build/Dockerfile.cpu_mklml index 39ce77c7d84d..9f235b461197 100644 --- a/tests/ci_build/Dockerfile.cpu_mklml +++ b/tests/ci_build/Dockerfile.cpu_mklml @@ -12,7 +12,7 @@ COPY install/ubuntu_install_perl.sh /install/ RUN /install/ubuntu_install_perl.sh # Add MKLML library, compatiable with Ubuntu16.04 -RUN wget --no-check-certificate -O /tmp/mklml.tgz https://github.com/01org/mkl-dnn/releases/download/v0.11/mklml_lnx_2018.0.1.20171227.tgz +RUN wget --no-check-certificate -O /tmp/mklml.tgz https://github.com/01org/mkl-dnn/releases/download/v0.12/mklml_lnx_2018.0.1.20171227.tgz RUN tar -zxvf /tmp/mklml.tgz && cp -rf mklml_*/* /usr/local/ && rm -rf mklml_* ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib diff --git a/tests/ci_build/Dockerfile.gpu_mklml b/tests/ci_build/Dockerfile.gpu_mklml index 7b5715fcf68f..dcef7f03a0c4 100644 --- a/tests/ci_build/Dockerfile.gpu_mklml +++ b/tests/ci_build/Dockerfile.gpu_mklml @@ -12,7 +12,7 @@ COPY install/ubuntu_install_scala.sh /install/ RUN /install/ubuntu_install_scala.sh # Add MKLML library, compatible with Ubuntu16.04 -RUN wget --no-check-certificate -O /tmp/mklml.tgz https://github.com/01org/mkl-dnn/releases/download/v0.11/mklml_lnx_2018.0.1.20171227.tgz +RUN wget --no-check-certificate -O /tmp/mklml.tgz https://github.com/01org/mkl-dnn/releases/download/v0.12/mklml_lnx_2018.0.1.20171227.tgz RUN tar -zxvf /tmp/mklml.tgz && cp -rf mklml_*/* /usr/local/ && rm -rf mklml_* ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index f2d001a7022d..57bf5c97c790 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -86,7 +86,18 @@ def __init__(self, **kwargs): assert 'numpy.float32' in lines[1] assert lines[2] == ')' - +def test_collect_paramters(): + net = nn.HybridSequential(prefix="test_") + with net.name_scope(): + net.add(nn.Conv2D(10, 3)) + net.add(nn.Dense(10, activation='relu')) + assert set(net.collect_params().keys()) == \ + set(['test_conv0_weight', 'test_conv0_bias','test_dense0_weight','test_dense0_bias']) + assert set(net.collect_params('.*weight').keys()) == \ + set(['test_conv0_weight', 'test_dense0_weight']) + assert set(net.collect_params('test_conv0_bias|test_dense0_bias').keys()) == \ + set(['test_conv0_bias', 'test_dense0_bias']) + def test_basic(): model = nn.Sequential() model.add(nn.Dense(128, activation='tanh', in_units=10, flatten=False)) diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py index 174d577556dd..bb903f483b30 100644 --- a/tests/python/unittest/test_kvstore.py +++ b/tests/python/unittest/test_kvstore.py @@ -61,6 +61,7 @@ def check_single_kv_pair(kv, key): check_single_kv_pair(init_kv(), 3) check_single_kv_pair(init_kv_with_str(), 'a') +@unittest.skip("test fails intermittently. temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/9384") def test_row_sparse_pull(): kv = init_kv_with_str('row_sparse') kv.init('e', mx.nd.ones(shape).tostype('row_sparse')) diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index ae248b0d0bc7..2d22391879ce 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -524,6 +524,87 @@ def test_adam(): compare_optimizer(opt1(sparse_update=True, **kwarg), opt2(**kwarg), shape, dtype, w_stype='row_sparse', g_stype='row_sparse') + +# Signum +class PySignum(mx.optimizer.Optimizer): + """The python reference of Signum optimizer. + + The optimizer updates the weight by: + + rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight + state = momentum * state + (1-momentum)*rescaled_grad + weight = (1 - lr * wd_lh) * weight - lr * sign(state) + + See the original paper at: https://jeremybernste.in/projects/amazon/signum.pdf + + For details of the update algorithm see + :class:`~mxnet.ndarray.signsgd_update` and :class:`~mxnet.ndarray.signum_update`. + + This optimizer accepts the following parameters in addition to those accepted + by :class:`.Optimizer`. + + Parameters + ---------- + momentum : float, optional + The momentum value. + wd_lh : float, optitional + The amount of decoupled weight decay regularization. + """ + def __init__(self, learning_rate=0.01, momentum=0.9, wd_lh = 0.0, **kwargs): + super(PySignum, self).__init__(learning_rate = learning_rate, **kwargs) + self.momentum = momentum + self.wd_lh = wd_lh + + def create_state(self, index, weight): + momentum = None + if self.momentum != 0.0: + momentum = mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype) + return momentum + + def update(self, index, weight, grad, state): + self._update_count(index) + lr = self._get_lr(index) + wd = self._get_wd(index) + + if state is not None: + mom = state + if self.clip_gradient is not None: + mom[:] = (self.momentum*mom - (1-self.momentum)*(wd*weight + + mx.nd.clip(grad*self.rescale_grad, -self.clip_gradient, self.clip_gradient))) + else: + mom[:] = self.momentum*mom - (1-self.momentum)*wd*weight - (1-self.momentum)*self.rescale_grad*grad + weight[:] = (1 - lr*self.wd_lh)*weight + lr*mx.nd.sign(mom) + else: + weight[:] = (1 - lr*(wd+self.wd_lh))*weight - lr*mx.nd.sign(grad) + +def test_signum(): + mx.random.seed(0) + opt1 = PySignum + opt2 = mx.optimizer.Signum + shape = (3, 4, 5) + cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}] + rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}] + wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}] + wd_lh_options = [{}, {'wd_lh': 0.015}, {'wd_lh': 0.0}] + mom_options = [{}, {'momentum': 0.9}] + lr_options = [{'learning_rate': 0.05},{'learning_rate': 0.01}] + for dtype in [np.float32, np.float64]: + for cg_option in cg_options: + for rg_option in rg_options: + for wd_option in wd_options: + for mp_option in wd_lh_options: + for lr_option in lr_options: + for mom_option in mom_options: + kwarg = {} + kwarg.update(cg_option) + kwarg.update(rg_option) + kwarg.update(wd_option) + kwarg.update(mp_option) + kwarg.update(lr_option) + kwarg.update(mom_option) + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype) + + # RMSProp class PyRMSProp(mx.optimizer.Optimizer): """RMSProp optimizer of Tieleman & Hinton, 2012,