diff --git a/README.md b/README.md index eec1ead..7455e00 100644 --- a/README.md +++ b/README.md @@ -1,61 +1,139 @@ # rnn-benchmarks -All benchmarks are reported using a Nvidia GeForce GTX TITAN X GPU. +All benchmarks are reported for a host with the following specifications : + * NVIDIA GeForce GTX TITAN X GPU + * Intel(R) Xeon(R) CPU E5-2630L v3 @ 1.80GHz + * CUDA 7.5, cudnnv5 These benchmarks compare the running time of various recurrent neural networks on different deep-learning libraries. -The networks (RNN or LSTM) take as input a 3D Tensor (batch_size x seq_length x input_size) and output the last hidden state, compute a MSE loss and backpropagate the errors through the network. Input layer size is always set to 100, and sequence length to 30. +The networks (RNN or LSTM) take as input a 3D Tensor `batch_size x seq_length x hidden_size` +and output the last hidden state, compute a MSE loss, backpropagate the errors through the network and do a simple update of the parameters (`params = params - lr*gradParams`). +The sequence length is always set to `30`. +The `hidden_size` specifies the size of the output and input layer of the networks. -The code of the scripts I ran is available. The implementations of each model on the different libraries each use the fastest implementations I was able to find. If you are aware of faster implementations, please let me know. I've reported results on Theano and Torch so far, but I will try to include many more libraries in the future. +The code of the scripts we ran are available. +The implementations of each model on the different libraries each use +the fastest implementations we were able to find. +If you are aware of faster implementations, please let me know. +I've reported results on Theano, Torch and TensorFlow so far, but we will try to include many more libraries in the future (including cudnn very soon). -The reported time is the average time needed to run a training example (and not a training batch), so the smaller the better. +The reported `Train` time is the average time needed to run (forward, backward, and update) a training example (and not a training batch), so the smaller the better. +We also report `Compile` time, which includes symbolic graph optimizations (Theano and TensorFlow compilation), as well as a forward and backward pass (to allocate memory). +While the compilation time isn't really a factor in production, it does increase debugging time, which is why we report it here. -### RNN +## LSTM -#### Hidden layer size 100 - Batch size 20 +This LSTM implementation used for these benchmarks does not use peephole connections between cell and gates. -| Library | Time (µs) | Forward only (µs) | -| ------------- | ------------- | ------------- | -| Theano | 253.9 | 87.82 | -| Torch | 315.4 | 121.8 | +### Batch Size 32 +#### Hidden Size 128 -#### Hidden layer size 500 - Batch size 20 +| Library | Compile (s) | Train (µs) | Forward only (µs) | +| ------------- | ------------- | ------------- | ------------- | +| Theano | 7.46 | 289.6 | 99.1 | +| Torch | 0.03 | 434.4 | 99.9 | +| TensorFlow | 3.91 | 820.0 | 266.7 | -| Library | Time (µs) | Forward only (µs) | -| ------------- | ------------- | ------------- | -| Torch | 376.0 | 143.1 | -| Theano | 498.4 | 182.9 | +#### Hidden Size 512 -#### Hidden layer size 1000 - Batch size 20 +| Library | Compile (s) | Train (µs) | Forward only (µs) | +| ------------- | ------------- | ------------- | ------------- | +| Theano | 7.59 | 619.4 | 200.9 | +| Torch | 0.19 | 610.7 | 201.7 | +| TensorFlow | 3.97 | 886.9 | 324.9 | -| Library | Time (µs) | Forward only (µs) | -| ------------- | ------------- | ------------- | -| Torch | 637.4 | 230.2 | -| Theano | 758.8 | 326.3 | +#### Hidden Size 1024 +| Library | Compile (s) | Train (µs) | Forward only (µs) | +| ------------- | ------------- | ------------- | ------------- | +| Theano | 9.62 | 1013.5 | 324.1 | +| Torch | 0.69 | 1139.8 | 346.3 | +| TensorFlow | 3.81 | 1329.2 | 562.7 | -### LSTM -#### Hidden layer size 100 - Batch size 20 +### Batch Size 128 -| Library | Time (µs) | Forward only (µs) | -| ------------- | ------------- | ------------- | -| Theano (FastLSTM) | 587.7 | 215.1 | -| Theano (LSTM) | 725.3 | 237.5 | -| Torch (Element-Research FastLSTM) | 1017.4 | 367.3 | -| Torch (Element-Research LSTM) | 3549.5 | 1630.8 | +#### Hidden Size 128 +| Library | Compile (s) | Train (µs) | Forward only (µs) | +| ------------- | ------------- | ------------- | ------------- | +| Theano | 7.38 | 102.9 | 25.6 | +| Torch | 0.03 | 109.8 | 25.2 | +| TensorFlow | 3.68 | 188.6 | 65.0 | -#### Hidden layer size 500 - Batch size 20 -| Library | Time (µs) | Forward only (µs) | -| ------------- | ------------- | ------------- | -| Theano (FastLSTM) | 1045.4 | 342.7 | -| Torch (Element-Research FastLSTM) | 1106.5 | 425.2 | -| Theano (LSTM) | 2298.1 | 736.4 | -| Torch (Element-Research LSTM) | 4636.1 | 2923.9 | +#### Hidden Size 512 +| Library | Compile (s) | Train (µs) | Forward only (µs) | +| ------------- | ------------- | ------------- | ------------- | +| Theano | 7.50 | 256.0 | 62.8 | +| Torch | 0.20 | 214.3 | 51.4 | +| TensorFlow | 3.73 | 255.2 | 114.2 | -FastLSTM implementations (for both Torch and Theano) do not use peephole connections between cell and gates, and compute the input, forget and output gates, as well as the hidden state, in the same operation. +#### Hidden Size 1024 + +| Library | Compile (s) | Train (µs) | Forward only (µs) | +| ------------- | ------------- | ------------- | ------------- | +| Theano | 7.45 | 583.4 | 160.2 | +| Torch | 0.75 | 558.1 | 112.4 | +| TensorFlow | 3.84 | 592.2 | 238.1 | + + +## RNN + +This section benchmarks a simple RNN implementation. + +### Batch Size 32 + +#### Hidden Size 128 + +| Library | Compile (s) | Train (µs) | Forward only (µs) | +| ------------- | ------------- | ------------- | ------------- | +| Theano | 4.31 | 104.6 | 30.9 | +| Torch | 0.05 | 259.53 | 103.06 | +| TensorFlow | 1.64 | 278.4 | 111.5 | + +#### Hidden Size 512 + +| Library | Compile (s) | Train (µs) | Forward only (µs) | +| ------------- | ------------- | ------------- | ------------- | +| Theano | 4.36 | 275.2 | 102.2 | +| Torch | 0.05 | 288.2 | 114.6 | +| TensorFlow | 1.62 | 349.7 | 218.4 | + +#### Hidden Size 1024 + +| Library | Compile (s) | Train (µs) | Forward only (µs) | +| ------------- | ------------- | ------------- | ------------- | +| Theano | 4.44 | 443.8 | 179.5 | +| Torch | 0.09 | 381.4 | 118.8 | +| TensorFlow | 1.72 | 530.0 | 241.7 | + +### Batch Size 128 + +#### Hidden Size 128 + +| Library | Compile (s) | Train (µs) | Forward only (µs) | +| ------------- | ------------- | ------------- | ------------- | +| Theano | 4.48 | 45.4 | 13.7 | +| Torch | 0.08 | 67.7 | 32.7 | +| TensorFlow | 1.70 | 75.5 | 33.6 | + +#### Hidden Size 512 + +| Library | Compile (s) | Train (µs) | Forward only (µs) | +| ------------- | ------------- | ------------- | ------------- | +| Theano | 4.40 | 79.0 | 23.8 | +| Torch | 0.09 | 73.5 | 34.2 | +| TensorFlow | 1.63 | 125.6 | 86.8 | + +#### Hidden Size 1024 + +| Library | Compile (s) | Train (µs) | Forward only (µs) | +| ------------- | ------------- | ------------- | ------------- | +| Theano | 4.38 | 147.8 | 50.3 | +| Torch | 0.13 | 150.2 | 64.7 | +| TensorFlow | 1.70 | 222.5 | 137.8 | diff --git a/tensorflow/README.md b/tensorflow/README.md index ab22a20..c733960 100644 --- a/tensorflow/README.md +++ b/tensorflow/README.md @@ -1,46 +1,170 @@ #TensorFlow benchmarks -provided by Maarten Bosma. +Provided by Maarten Bosma. I used the build-in rnn libary. ``basic_lstm`` is the Tensorflow equivalent of FastLSTM. -To install TensorFlow follow [these instructions](https://www.tensorflow.org/versions/r0.7/get_started/os_setup.html#pip-installation). +These results are produced using TensorFlow 0.8, cuda 7.5, cudnnv5, turned off ondemand cpu governor [1], Intel(R) Xeon(R) CPU E5-2630L v3 @ 1.80GHz, Titan X: -## Results +To install TensorFlow from source: + * https://www.tensorflow.org/versions/r0.8/get_started/os_setup.html#installing-from-sources + * http://stackoverflow.com/questions/34239537/how-to-update-tensorflow-from-source + +## Fast LSTM -These results are produced using TensorFlow 0.7.1, cuda 7.5, cudnnv4, turned off ondemand cpu governor [1], Intel(R) Xeon(R) CPU E5-1650 v3 @ 3.50GHz, Titan X: + + +### 30 x 32 x 128 + +``` +$ python rnn.py -n basic_lstm -b 32 -l 128 -s 30 +Setup : compile + forward/backward x 1 +--- 3.91482686996 seconds +Forward: +--- 32000 samples in 8.53500294685 seconds (3749.266427 samples/s, 0.0002667 s/sample) --- +Forward + Backward: +--- 32000 samples in 26.2391839027 seconds (1219.550125 samples/s, 0.0008200 s/sample) --- +``` + +### 30 x 32 x 512 + +``` +python rnn.py -n basic_lstm -b 32 -l 512 -s 30 +Setup : compile + forward/backward x 1 +--- 3.97159981728 seconds +Forward: +--- 32000 samples in 10.3965659142 seconds (3077.939414 samples/s, 0.0003249 s/sample) --- +Forward + Backward: +--- 32000 samples in 28.3808200359 seconds (1127.522036 samples/s, 0.0008869 s/sample) --- +``` + +### 30 x 32 x 1024 + + +``` +python rnn.py -n basic_lstm -b 32 -l 1024 -s 30 +Setup : compile + forward/backward x 1 +--- 3.81890392303 seconds +Forward: +--- 32000 samples in 18.0062820911 seconds (1777.157541 samples/s, 0.0005627 s/sample) --- +Forward + Backward: +--- 32000 samples in 42.533454895 seconds (752.348947 samples/s, 0.0013292 s/sample) --- +``` + + +### 30 x 128 x 128 + +``` +$ python rnn.py -n basic_lstm -b 128 -l 128 -s 30 +Setup : compile + forward/backward x 1 +--- 3.68258690834 seconds +Forward: +--- 128000 samples in 8.3175599575 seconds (15389.128621 samples/s, 0.0000650 s/sample) --- +Forward + Backward: +--- 128000 samples in 24.1425020695 seconds (5301.853123 samples/s, 0.0001886 s/sample) --- + +``` + +### 30 x 128 x 512 + +``` +python rnn.py -n basic_lstm -b 128 -l 512 -s 30 +Setup : compile + forward/backward x 1 +--- 3.72586607933 seconds +Forward: +--- 128000 samples in 14.6179850101 seconds (8756.336794 samples/s, 0.0001142 s/sample) --- +Forward + Backward: +--- 128000 samples in 32.6627261639 seconds (3918.840067 samples/s, 0.0002552 s/sample) --- + +``` + +### 30 x 128 x 1024 ``` -+ python rnn.py -n rnn -b 20 -i 100 -l 100 -s 30 +python rnn.py -n basic_lstm -b 128 -l 1024 -s 30 +Setup : compile + forward/backward x 1 +--- 3.84206986427 seconds Forward: ---- 100000 samples in 11.039894104 seconds (9058.057117 samples/s) --- +--- 128000 samples in 30.4814198017 seconds (4199.279457 samples/s, 0.0002381 s/sample) --- Forward + Backward: ---- 100000 samples in 25.300686121 seconds (3952.461833 samples/s) --- -+ python rnn.py -n rnn -b 20 -i 100 -l 500 -s 30 +--- 128000 samples in 75.8014390469 seconds (1688.622295 samples/s, 0.0005922 s/sample) --- + +``` + +## RNN + +### 30 x 32 x 128 + +``` +python rnn.py -n rnn -b 32 -l 128 -s 30 +Setup : compile + forward/backward x 1 +--- 1.6487121582 seconds Forward: ---- 100000 samples in 19.6222681999 seconds (5096.250552 samples/s) --- +--- 32000 samples in 3.56794595718 seconds (8968.745711 samples/s, 0.0001115 s/sample) --- Forward + Backward: ---- 100000 samples in 43.0670762062 seconds (2321.959292 samples/s) --- -+ python rnn.py -n basic_lstm -b 20 -i 100 -l 100 -s 30 +--- 32000 samples in 8.91037988663 seconds (3591.317139 samples/s, 0.0002784 s/sample) --- +``` + +### 30 x 32 x 512 + +``` +python rnn.py -n rnn -b 32 -l 512 -s 30 +Setup : compile + forward/backward x 1 +--- 1.62368106842 seconds Forward: ---- 100000 samples in 25.3170599937 seconds (3949.905568 samples/s) --- +--- 32000 samples in 6.98823904991 seconds (4579.122118 samples/s, 0.0002184 s/sample) --- Forward + Backward: ---- 100000 samples in 77.6742260456 seconds (1287.428310 samples/s) --- -+ python rnn.py -n basic_lstm -b 20 -i 100 -l 500 -s 30 +--- 32000 samples in 11.1912858486 seconds (2859.367586 samples/s, 0.0003497 s/sample) --- +``` + +### 30 x 32 x 1024 + +``` +python rnn.py -n rnn -b 32 -l 1024 -s 30 +Setup : compile + forward/backward x 1 +--- 1.72744393349 seconds Forward: ---- 100000 samples in 36.4037480354 seconds (2746.969825 samples/s) --- +--- 32000 samples in 7.73560094833 seconds (4136.718041 samples/s, 0.0002417 s/sample) --- Forward + Backward: ---- 100000 samples in 104.032881021 seconds (961.234534 samples/s) --- -+ python rnn.py -n lstm -b 20 -i 100 -l 100 -s 30 +--- 32000 samples in 16.9597899914 seconds (1886.815816 samples/s, 0.0005300 s/sample) --- +``` + +### 30 x 128 x 128 + +``` +python rnn.py -n rnn -b 128 -l 128 -s 30 +Setup : compile + forward/backward x 1 +--- 1.698335886 seconds Forward: ---- 100000 samples in 26.2394618988 seconds (3811.053590 samples/s) --- +--- 128000 samples in 4.29631710052 seconds (29792.959180 samples/s, 0.0000336 s/sample) --- Forward + Backward: ---- 100000 samples in 81.6460819244 seconds (1224.798498 samples/s) --- -+ python rnn.py -n lstm -b 20 -i 100 -l 500 -s 30 +--- 128000 samples in 9.66468191147 seconds (13244.098582 samples/s, 0.0000755 s/sample) --- +``` + +### 30 x 128 x 512 + +``` +python rnn.py -n rnn -b 128 -l 512 -s 30 +Setup : compile + forward/backward x 1 +--- 1.63733696938 seconds Forward: ---- 100000 samples in 36.3097510338 seconds (2754.080981 samples/s) --- +--- 128000 samples in 11.1102721691 seconds (11520.869881 samples/s, 0.0000868 s/sample) --- Forward + Backward: ---- 100000 samples in 104.501612902 seconds (956.923021 samples/s) --- +--- 128000 samples in 16.0786859989 seconds (7960.849538 samples/s, 0.0001256 s/sample) --- +``` + +### 30 x 128 x 1024 + ``` +python rnn.py -n rnn -b 128 -l 1024 -s 30 +Setup : compile + forward/backward x 1 +--- 1.7014939785 seconds +Forward: +--- 128000 samples in 17.6321749687 seconds (7259.456092 samples/s, 0.0001378 s/sample) --- +Forward + Backward: +--- 128000 samples in 28.4844169617 seconds (4493.685097 samples/s, 0.0002225 s/sample) --- + +``` + [1] Turning on performance governor: `sudo bash -c 'for i in ls /sys/devices/system/cpu/*/cpufreq/scaling_governor; do echo 'performance' > $i; done;'` diff --git a/tensorflow/rnn.py b/tensorflow/rnn.py index 5c55bcd..6bed0b9 100755 --- a/tensorflow/rnn.py +++ b/tensorflow/rnn.py @@ -12,8 +12,8 @@ def get_feed_dict(x_data, y_data=None): if y_data is not None: feed_dict[y] = y_data - for i in xrange(X_data.shape[1]): - feed_dict[x[i]] = x_data[:, i, :] + for i in xrange(x_data.shape[0]): + feed_dict[x[i]] = x_data[i, :, :] return feed_dict @@ -21,55 +21,63 @@ def get_feed_dict(x_data, y_data=None): # Parameters optparser = optparse.OptionParser() optparser.add_option("-n", "--network_type", default='rnn', help="Network type (rnn, lstm, basic_lstm)") -optparser.add_option("-i", "--input_size", default=100, type='int', help="Input layer size") optparser.add_option("-l", "--hidden_size", default=100, type='int', help="Hidden layer size") optparser.add_option("-s", "--seq_length", default=30, type='int', help="Sequence length") optparser.add_option("-b", "--batch_size", default=20, type='int', help="Batch size") opts = optparser.parse_args()[0] network_type = opts.network_type -input_size = opts.input_size +print(network_type) +hidden_size = opts.hidden_size hidden_size = opts.hidden_size seq_length = opts.seq_length batch_size = opts.batch_size -n_samples = 100000 +n_batch = 1000 +n_samples = batch_size * n_batch # Data -X_data = np.random.rand(n_samples, seq_length, input_size).astype(np.float32) -Y_data = np.random.rand(n_samples, hidden_size).astype(np.float32) - -x = [tf.placeholder(tf.float32, [batch_size, input_size], name="x") for i in range(seq_length)] -y = tf.placeholder(tf.float32, [batch_size, hidden_size], name="y") - -if network_type == 'rnn': - cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) -elif network_type == 'lstm': - cell = tf.nn.rnn_cell.LSTMCell(hidden_size, input_size) -elif network_type == 'basic_lstm': - cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_size) -else: - raise Exception('Unknown network! '+network_type) - -output, _cell_state = rnn.rnn(cell, x, dtype=tf.float32) -cost = tf.reduce_sum((output[-1] - y) ** 2) - -optim = tf.train.GradientDescentOptimizer(0.01) -train_op = optim.minimize(cost) - -session = tf.Session() -session.run(tf.initialize_all_variables()) - -start = time.time() -for k, i in enumerate(xrange(0, n_samples, batch_size)): - session.run(output[-1], feed_dict=get_feed_dict(X_data[i:i+batch_size])) -end = time.time() -print "Forward:" -print "--- %i samples in %s seconds (%f samples/s, %.7f s/sample) ---" % (n_samples, end - start, n_samples / (end - start), (end - start) / n_samples) - -start = time.time() -for k, i in enumerate(xrange(0, n_samples, batch_size)): - session.run(train_op, feed_dict=get_feed_dict(X_data[i:i+batch_size], Y_data[i:i+batch_size])) -end = time.time() -print "Forward + Backward:" -print "--- %i samples in %s seconds (%f samples/s, %.7f s/sample) ---" % (n_samples, end - start, n_samples / (end - start), (end - start) / n_samples) +xinput = np.random.rand(seq_length, batch_size, hidden_size).astype(np.float32) +ytarget = np.random.rand(batch_size, hidden_size).astype(np.float32) + +with tf.device('/gpu:0'): + + x = [tf.placeholder(tf.float32, [batch_size, hidden_size], name="x") for i in range(seq_length)] + y = tf.placeholder(tf.float32, [batch_size, hidden_size], name="y") + + if network_type == 'rnn': + cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + elif network_type == 'lstm': + cell = tf.nn.rnn_cell.LSTMCell(hidden_size, hidden_size) + elif network_type == 'basic_lstm': + cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_size) + else: + raise Exception('Unknown network! '+network_type) + + print "Compiling..." + start = time.time() + output, _cell_state = rnn.rnn(cell, x, dtype=tf.float32) + cost = tf.reduce_sum((output[-1] - y) ** 2) + + optim = tf.train.GradientDescentOptimizer(0.01) + train_op = optim.minimize(cost) + + session = tf.Session() + session.run(tf.initialize_all_variables()) + session.run(train_op, feed_dict=get_feed_dict(xinput, ytarget)) + print "Setup : compile + forward/backward x 1" + print "--- %s seconds" % (time.time() - start) + + start = time.time() + for i in xrange(0, n_batch): + session.run(output[-1], feed_dict=get_feed_dict(xinput)) + end = time.time() + print "Forward:" + print "--- %i samples in %s seconds (%f samples/s, %.7f s/sample) ---" % (n_samples, end - start, n_samples / (end - start), (end - start) / n_samples) + + start = time.time() + for i in xrange(0, n_batch): + session.run(train_op, feed_dict=get_feed_dict(xinput, ytarget)) + end = time.time() + print "Forward + Backward:" + print "--- %i samples in %s seconds (%f samples/s, %.7f s/sample) ---" % (n_samples, end - start, n_samples / (end - start), (end - start) / n_samples) diff --git a/theano/README.md b/theano/README.md new file mode 100644 index 0000000..feb8540 --- /dev/null +++ b/theano/README.md @@ -0,0 +1,183 @@ +# Theano Benchmark Log + +Cuda 7.5, cudnnv5, Intel(R) Xeon(R) CPU E5-2630L v3 @ 1.80GHz, Titan X. + +## Fast LSTM + + + +### 30 x 32 x 128 +``` +THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python rnn.py -n 'fastlstm' -l 128 -s 30 -b 32 +Using gpu device 0: GeForce GTX TITAN X (CNMeM is disabled, cuDNN not available) +Compiling... +Setup : compile + forward/backward x 1 +--- 7.45822191238 seconds +Forward: +--- 32000 samples in 3.17055702209 seconds (10092.863739 samples/s, 0.0000991 s/sample) --- +Forward + Backward: +--- 32000 samples in 9.26702213287 seconds (3453.104950 samples/s, 0.0002896 s/sample) --- +``` +### 30 x 32 x 512 + + +``` +$ THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python rnn.py -n 'fastlstm' -l 512 -s 30 -b 32 +Using gpu device 0: GeForce GTX TITAN X (CNMeM is disabled, cuDNN not available) +Compiling... +Setup : compile + forward/backward x 1 +--- 7.58512711525 seconds +Forward: +--- 32000 samples in 6.42896199226 seconds (4977.475374 samples/s, 0.0002009 s/sample) --- +Forward + Backward: +--- 32000 samples in 19.8206739426 seconds (1614.475880 samples/s, 0.0006194 s/sample) --- +``` + +### 30 x 32 x 1024 + +``` +$ THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python rnn.py -n 'fastlstm' -l 1024 -s 30 -b 32 +Using gpu device 0: GeForce GTX TITAN X (CNMeM is disabled, cuDNN not available) +Compiling... +Setup : compile + forward/backward x 1 +--- 9.6281080246 seconds +Forward: +--- 32000 samples in 10.3716170788 seconds (3085.343371 samples/s, 0.0003241 s/sample) --- +Forward + Backward: +--- 32000 samples in 32.4317178726 seconds (986.688406 samples/s, 0.0010135 s/sample) --- +``` + +### 30 x 128 x 128 + + +``` +$ THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python rnn.py -n 'fastlstm' -l 128 -s 30 -b 128 +Using gpu device 0: GeForce GTX TITAN X (CNMeM is disabled, cuDNN not available) +Compiling... +Setup : compile + forward/backward x 1 +--- 7.37970685959 seconds +Forward: +--- 128000 samples in 3.27810716629 seconds (39046.923577 samples/s, 0.0000256 s/sample) --- +Forward + Backward: +--- 128000 samples in 13.1759991646 seconds (9714.633281 samples/s, 0.0001029 s/sample) -- +``` + +### 30 x 128 x 512 + +``` +$ THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python rnn.py -n 'fastlstm' -l 512 -s 30 -b 128 +Using gpu device 0: GeForce GTX TITAN X (CNMeM is disabled, cuDNN not available) +Compiling... +Setup : compile + forward/backward x 1 +--- 7.49780893326 seconds +Forward: +--- 128000 samples in 8.03891611099 seconds (15922.544561 samples/s, 0.0000628 s/sample) --- +Forward + Backward: +--- 128000 samples in 32.7736029625 seconds (3905.582189 samples/s, 0.0002560 s/sample) --- + +``` + +### 30 x 128 x 1024 + + +``` +THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python rnn.py -n 'fastlstm' -l 1024 -s 30 -b 128 +Using gpu device 0: GeForce GTX TITAN X (CNMeM is disabled, cuDNN not available) +Compiling... +Setup : compile + forward/backward x 1 +--- 7.44703698158 seconds +Forward: +--- 128000 samples in 20.5059478283 seconds (6242.091371 samples/s, 0.0001602 s/sample) --- +Forward + Backward: +--- 128000 samples in 74.6807880402 seconds (1713.961560 samples/s, 0.0005834 s/sample) --- +``` + +## RNN + + +### 30 x 32 x 128 + +``` +$ THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python rnn.py -n 'rnn' -l 128 -s 30 -b 32 +Using gpu device 0: GeForce GTX TITAN X (CNMeM is disabled, cuDNN not available) +Compiling... +Setup : compile + forward/backward x 1 +--- 4.309237957 seconds +Forward: +--- 32000 samples in 0.989920139313 seconds (32325.839963 samples/s, 0.0000309 s/sample) --- +Forward + Backward: +--- 32000 samples in 3.34791088104 seconds (9558.199467 samples/s, 0.0001046 s/sample) --- +``` + +### 30 x 32 x 512 + + +``` +$ THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python rnn.py -n 'rnn' -l 512 -s 30 -b 32 +Using gpu device 0: GeForce GTX TITAN X (CNMeM is disabled, cuDNN not available) +Compiling... +Setup : compile + forward/backward x 1 +--- 4.36186599731 seconds +Forward: +--- 32000 samples in 3.27020597458 seconds (9785.316353 samples/s, 0.0001022 s/sample) --- +Forward + Backward: +--- 32000 samples in 8.80706095695 seconds (3633.448225 samples/s, 0.0002752 s/sample) --- +``` + +### 30 x 32 x 1024 + +``` +$ THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python rnn.py -n 'rnn' -l 1024 -s 30 -b 32 +Using gpu device 0: GeForce GTX TITAN X (CNMeM is disabled, cuDNN not available) +Compiling... +Setup : compile + forward/backward x 1 +--- 4.44132804871 seconds +Forward: +--- 32000 samples in 5.74468803406 seconds (5570.363405 samples/s, 0.0001795 s/sample) --- +Forward + Backward: +--- 32000 samples in 14.2010200024 seconds (2253.359265 samples/s, 0.0004438 s/sample) --- + +``` + +### 30 x 128 x 128 + +``` +$ THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python rnn.py -n 'rnn' -l 128 -s 30 -b 128 +Using gpu device 0: GeForce GTX TITAN X (CNMeM is disabled, cuDNN not available) +Compiling... +Setup : compile + forward/backward x 1 +--- 4.48347306252 seconds +Forward: +--- 128000 samples in 1.74959516525 seconds (73159.781498 samples/s, 0.0000137 s/sample) --- +Forward + Backward: +--- 128000 samples in 5.81079101562 seconds (22027.982018 samples/s, 0.0000454 s/sample) --- + +``` + +### 30 x 128 x 512 + +``` +$ THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python rnn.py -n 'rnn' -l 512 -s 30 -b 128 +Using gpu device 0: GeForce GTX TITAN X (CNMeM is disabled, cuDNN not available) +Compiling... +Setup : compile + forward/backward x 1 +--- 4.40771007538 seconds +Forward: +--- 128000 samples in 3.04104089737 seconds (42090.851231 samples/s, 0.0000238 s/sample) --- +Forward + Backward: +--- 128000 samples in 10.1157169342 seconds (12653.576690 samples/s, 0.0000790 s/sample) --- +``` + +### 30 x 128 x 1024 + +``` +$ THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python rnn.py -n 'rnn' -l 1024 -s 30 -b 128 +Using gpu device 0: GeForce GTX TITAN X (CNMeM is disabled, cuDNN not available) +Compiling... +Setup : compile + forward/backward x 1 +--- 4.38037991524 seconds +Forward: +--- 128000 samples in 6.43677687645 seconds (19885.728907 samples/s, 0.0000503 s/sample) --- +Forward + Backward: +--- 128000 samples in 18.919303894 seconds (6765.576615 samples/s, 0.0001478 s/sample) --- +``` diff --git a/theano/rnn.py b/theano/rnn.py index 65120d5..b8c7e50 100755 --- a/theano/rnn.py +++ b/theano/rnn.py @@ -219,14 +219,12 @@ def recurrence(x_t, c_tm1, h_tm1): optparser = optparse.OptionParser() optparser.add_option("-n", "--network_type", default='rnn', help="Network type (rnn, lstm, fastlstm)") -optparser.add_option("-i", "--input_size", default=100, type='int', help="Input layer size") -optparser.add_option("-l", "--hidden_size", default=100, type='int', help="Hidden layer size") +optparser.add_option("-l", "--hidden_size", default=128, type='int', help="Hidden layer size") optparser.add_option("-s", "--seq_length", default=30, type='int', help="Sequence length") -optparser.add_option("-b", "--batch_size", default=20, type='int', help="Batch size") +optparser.add_option("-b", "--batch_size", default=32, type='int', help="Batch size") opts = optparser.parse_args()[0] network_type = opts.network_type -input_size = opts.input_size hidden_size = opts.hidden_size seq_length = opts.seq_length batch_size = opts.batch_size @@ -234,9 +232,9 @@ def recurrence(x_t, c_tm1, h_tm1): # Data -n_samples = 100000 -x_values = theano.shared(np.random.rand(n_samples, seq_length, input_size).astype(np.float32)) -y_values = theano.shared(np.random.rand(n_samples, hidden_size).astype(np.float32)) +n_batch = 1000 +xinput = theano.shared(np.random.rand(seq_length, batch_size, hidden_size).astype(np.float32)) +ytarget = theano.shared(np.random.rand(batch_size, hidden_size).astype(np.float32)) # Network @@ -248,39 +246,41 @@ def recurrence(x_t, c_tm1, h_tm1): y = T.fmatrix() if network_type == 'rnn': - rnn = RNN(input_size, hidden_size) + rnn = RNN(hidden_size, hidden_size) elif network_type == 'lstm': - rnn = LSTM(input_size, hidden_size) + rnn = LSTM(hidden_size, hidden_size) elif network_type == 'fastlstm': - rnn = FastLSTM(input_size, hidden_size) + rnn = FastLSTM(hidden_size, hidden_size) else: raise Exception('Unknown network!') -output = rnn.link(x) +output = rnn.link(x.dimshuffle(1, 0, 2)) cost = ((output - y) ** 2).mean() updates = [(p, p - theano.shared(np.float32(0.01)) * g) for p, g in zip(rnn.params, T.grad(cost, rnn.params))] -print 'compiling...' -f_test = theano.function(inputs=[index], outputs=output, givens={x: x_values[index:index + batch_size]}) -f_train = theano.function(inputs=[index], outputs=cost, updates=updates, givens={x: x_values[index:index + batch_size], y: y_values[index:index + batch_size]}) -f_train(0) +print 'Compiling...' +f_test = theano.function(inputs=[], outputs=output, givens={x: xinput}) +f_train = theano.function(inputs=[], outputs=cost, updates=updates, givens={x: xinput, y: ytarget}) +f_train() +theano.sandbox.cuda.synchronize() print "Setup : compile + forward/backward x 1" print "--- %s seconds" % (time.time() - start) +n_samples = n_batch * batch_size start = time.time() -for k, i in enumerate(xrange(0, n_samples, batch_size)): - # if k % 100 == 0: - # print k - f_test(i) +for i in xrange(0, n_batch): + f_test() +theano.sandbox.cuda.synchronize() end = time.time() print "Forward:" print "--- %i samples in %s seconds (%f samples/s, %.7f s/sample) ---" % (n_samples, end - start, n_samples / (end - start), (end - start) / n_samples) start = time.time() -for k, i in enumerate(xrange(0, n_samples, batch_size)): +for i in xrange(0, n_batch): # if k % 100 == 0: # print k - f_train(i) + f_train() +theano.sandbox.cuda.synchronize() end = time.time() print "Forward + Backward:" print "--- %i samples in %s seconds (%f samples/s, %.7f s/sample) ---" % (n_samples, end - start, n_samples / (end - start), (end - start) / n_samples) diff --git a/torch/README.md b/torch/README.md new file mode 100644 index 0000000..c3451d4 --- /dev/null +++ b/torch/README.md @@ -0,0 +1,165 @@ +# Torch Benchmark + +Provided by [Nicholas Leonard](https://github.com/nicholas-leonard). + +Benchmark script uses [Element-Research/rnn](https://github.com/Element-Research/rnn). + +Lua 5.2, Cuda 7.5, cudnnv5, Intel(R) Xeon(R) CPU E5-2630L v3 @ 1.80GHz, Titan X: + + +## Fast LSTM + + +### 30 x 32 x 128 + +``` +$ th rnn.lua -network 'fastlstm' -batchsize 32 -hiddensize 128 -seqlen 30 +Setup : compile + forward/backward x 1 +--- 0.024899005889893 seconds --- +Forward: +--- 32000 samples in 3.1959130764008 seconds (10012.885074946 samples/s, 99.871315062046 microsec/samples) --- +Forward + Backward: +--- 32000 samples in 13.899139881134 seconds (2302.3021987976 samples/s, 434.34784561396 microsec/samples) --- +``` + +### 30 x 32 x 512 + +``` +$ th rnn.lua -network 'fastlstm' -batchsize 32 -hiddensize 512 -seqlen 30 +Setup : compile + forward/backward x 1 +--- 0.18875980377197 seconds --- +Forward: +--- 32000 samples in 6.4531669616699 seconds (4958.8108406272 samples/s, 201.66125148535 microsec/samples) --- +Forward + Backward: +--- 32000 samples in 19.541891098022 seconds (1637.5083655011 samples/s, 610.68390309811 microsec/samples) --- +``` + +### 30 x 32 x 1024 + +``` +$ th rnn.lua -network 'fastlstm' -batchsize 32 -hiddensize 1024 -seqlen 30 +Setup : compile + forward/backward x 1 +--- 0.69268393516541 seconds --- +Forward: +--- 32000 samples in 11.082577943802 seconds (2887.4174470646 samples/s, 346.33024781942 microsec/samples) --- +Forward + Backward: +--- 32000 samples in 36.474525928497 seconds (877.32484315331 samples/s, 1139.8286595941 microsec/samples) --- +``` + +### 30 x 128 x 128 + + +``` +$ th rnn.lua -network 'fastlstm' -batchsize 128 -hiddensize 128 -seqlen 30 +Setup : compile + forward/backward x 1 +--- 0.028716802597046 seconds --- +Forward: +--- 128000 samples in 3.2250719070435 seconds (39689.110895787 samples/s, 25.195827707648 microsec/samples) --- +Forward + Backward: +--- 128000 samples in 14.058291912079 seconds (9104.9498912316 samples/s, 109.83036831021 microsec/samples) --- +``` + +### 30 x 128 x 512 + + +``` +$ th rnn.lua -network 'fastlstm' -batchsize 128 -hiddensize 512 -seqlen 30 +Setup : compile + forward/backward x 1 +--- 0.19667100906372 seconds --- +Forward: +--- 128000 samples in 6.5813970565796 seconds (19448.779340937 samples/s, 51.417108625174 microsec/samples) --- +Forward + Backward: +--- 128000 samples in 27.426359891891 seconds (4667.0445070921 samples/s, 214.26836587489 microsec/samples) --- +``` + +### 30 x 128 x 1024 + +``` +$ th rnn.lua -network 'fastlstm' -batchsize 128 -hiddensize 1024 -seqlen 30 +Setup : compile + forward/backward x 1 +--- 0.74531388282776 seconds --- +Forward: +--- 128000 samples in 14.383507966995 seconds (8899.0845165442 samples/s, 112.37110942602 microsec/samples) --- +Forward + Backward: +--- 128000 samples in 71.433391094208 seconds (1791.8792478834 samples/s, 558.07331949472 microsec/samples) --- +``` + + + + +## RNN + + +### 30 x 32 x 128 + +``` +$ th rnn.lua -network 'rnn' -batchsize 32 -hiddensize 128 -seqlen 30 +Setup : compile + forward/backward x 1 +--- 0.045458793640137 seconds --- +Forward: +--- 32000 samples in 3.2980129718781 seconds (9702.8295844296 samples/s, 103.06271910667 microsec/samples) --- +Forward + Backward: +--- 32000 samples in 8.305154800415 seconds (3853.0314888602 samples/s, 259.53590124846 microsec/samples) --- + +``` + +### 30 x 32 x 512 + +``` +$ th rnn.lua -network 'rnn' -batchsize 32 -hiddensize 512 -seqlen 30 +Setup : compile + forward/backward x 1 +--- 0.053925037384033 seconds --- +Forward: +--- 32000 samples in 3.6663720607758 seconds (8727.9910213711 samples/s, 114.57390338182 microsec/samples) --- +Forward + Backward: +--- 32000 samples in 9.2218749523163 seconds (3470.0127856443 samples/s, 288.18337619305 microsec/samples) --- +``` + +### 30 x 32 x 1024 + +``` +$ th rnn.lua -network 'rnn' -batchsize 32 -hiddensize 1024 -seqlen 30 +Setup : compile + forward/backward x 1 +--- 0.08701491355896 seconds --- +Forward: +--- 32000 samples in 3.8027799129486 seconds (8414.9119629321 samples/s, 118.83665621281 microsec/samples) --- +Forward + Backward: +--- 32000 samples in 12.205145835876 seconds (2621.8464374057 samples/s, 381.4105913043 microsec/samples) --- +``` + +### 30 x 128 x 128 + +``` +$ th rnn.lua -network 'rnn' -batchsize 128 -hiddensize 128 -seqlen 30 +Setup : compile + forward/backward x 1 +--- 0.078629016876221 seconds --- +Forward: +--- 128000 samples in 4.1859209537506 seconds (30578.752442332 samples/s, 32.702445983887 microsec/samples) --- +Forward + Backward: +--- 128000 samples in 8.6592428684235 seconds (14781.904624814 samples/s, 67.650280892849 microsec/samples) --- +``` + +### 30 x 128 x 512 + +``` +$ th rnn.lua -network 'rnn' -batchsize 128 -hiddensize 512 -seqlen 30 +Setup : compile + forward/backward x 1 +--- 0.088251113891602 seconds --- +Forward: +--- 128000 samples in 4.383120059967 seconds (29203.014867419 samples/s, 34.243039786816 microsec/samples) --- +Forward + Backward: +--- 128000 samples in 9.4049069881439 seconds (13609.928358313 samples/s, 73.475772514939 microsec/samples) --- +``` + +### 30 x 128 x 1024 + +``` +$ th rnn.lua -network 'rnn' -batchsize 128 -hiddensize 1024 -seqlen 30 +Setup : compile + forward/backward x 1 +--- 0.12880301475525 seconds --- +Forward: +--- 128000 samples in 8.2753868103027 seconds (15467.566064044 samples/s, 64.651412889361 microsec/samples) --- +Forward + Backward: +--- 128000 samples in 19.230028152466 seconds (6656.2610449056 samples/s, 150.23449249566 microsec/samples) --- +``` + diff --git a/torch/rnn.lua b/torch/rnn.lua index c0bfcb3..93e210c 100644 --- a/torch/rnn.lua +++ b/torch/rnn.lua @@ -13,74 +13,77 @@ nn.FastLSTM.usenngraph = true cmd = torch.CmdLine() cmd:text() cmd:text('Options') -cmd:option('-nSamples', 100000, 'Number of samples') -cmd:option('-networkType', 'lstm', 'Network type') -cmd:option('-inputSize', 100, 'Neural network input size') -cmd:option('-hiddenSize', 100, 'Neural network hidden layer size') -cmd:option('-seqLength', 30, 'Sequence length') -cmd:option('-batchSize', 20, 'Batch size') -cmd:option('-cpu', false, 'Run on CPU') +cmd:option('-nbatch', 1000, 'Number of samples') +cmd:option('-network', 'fastlstm', 'Network type') +cmd:option('-hiddensize', 128, 'Neural network input and output size') +cmd:option('-seqlen', 30, 'Sequence length') +cmd:option('-batchsize', 20, 'Batch size') cmd:text() for k, v in pairs(cmd:parse(arg)) do _G[k] = v end -local xValues = torch.rand(nSamples, seqLength * inputSize) -local yValues = torch.rand(nSamples, hiddenSize) -if cpu ~= true then - xValues = xValues:cuda() - yValues = yValues:cuda() -end -local xBatches = xValues:split(batchSize, 1) -local yBatches = yValues:split(batchSize, 1) +local input = torch.rand(seqlen, batchsize, hiddensize):cuda() +local target = torch.rand(batchsize, hiddensize):cuda() local a = torch.Timer() local rnn -if networkType == 'rnn' then +if network == 'rnn' then rnn = nn.Sequential() :add(nn.JoinTable(1,1)) - :add(nn.Linear(inputSize+hiddenSize, hiddenSize)) + :add(nn.Linear(hiddensize*2, hiddensize)) :add(nn.Sigmoid()) - rnn = nn.Recurrence(rnn, hiddenSize, 1) -elseif networkType == 'lstm' then - rnn = nn.FastLSTM(inputSize, hiddenSize) + rnn = nn.Recurrence(rnn, hiddensize, 1) + rnn = nn.Sequential() + :add(nn.Sequencer(rnn)) + :add(nn.Select(1,-1)) +elseif network == 'lstm' then -- ( no peephole connections) + rnn = nn.LSTM(hiddensize, hiddensize) + rnn = nn.Sequential() + :add(nn.Sequencer(rnn)) + :add(nn.Select(1,-1)) +elseif network == 'oldfastlstm' then -- ( no peephole connections) + rnn = nn.FastLSTM(hiddensize, hiddensize) + rnn = nn.Sequential() + :add(nn.Sequencer(rnn)) + :add(nn.Select(1,-1)) +elseif network == 'fastlstm' then -- like fastlstm but faster ( no peephole connections) + rnn = nn.SeqLSTM(hiddensize, hiddensize) + rnn = nn.Sequential() + :add(rnn) + :add(nn.Select(1,-1)) else error('Unkown network type!') end -local rnn = nn.Sequential():add(nn.Sequencer(rnn)):add(nn.SelectTable(-1)) + local criterion = nn.MSECriterion() if cpu ~= true then rnn:cuda() criterion:cuda() end -local input = xBatches[1]:split(inputSize, 2) -criterion:forward(rnn:forward(input), yBatches[1]) -rnn:backward(input, criterion:backward(rnn.output, yBatches[1])) -if cpu ~= true then cutorch.synchronize() end +criterion:forward(rnn:forward(input), target) +rnn:backward(input, criterion:backward(rnn.output, target)) +cutorch.synchronize() print("Setup : compile + forward/backward x 1") print("--- " .. a:time().real .. " seconds ---") a:reset() -for i = 1, #xBatches do - if (i % 1000 == 0) then - print(i) - end - rnn:forward(xBatches[i]:split(inputSize, 2)) +for i = 1, nbatch do + rnn:forward(input) end -if cpu ~= true then cutorch.synchronize() end +cutorch.synchronize() print("Forward:") -print("--- " .. nSamples .. " samples in " .. a:time().real .. " seconds (" .. nSamples / a:time().real .. " samples/s) ---") +local nSamples = nbatch * batchsize +local speed = nSamples / a:time().real +print("--- " .. nSamples .. " samples in " .. a:time().real .. " seconds (" .. speed .. " samples/s, " .. 1000000/speed .. " microsec/samples) ---") a:reset() -for i = 1, #xBatches do - if (i % 1000 == 0) then - print(i) - end - local input = xBatches[i]:split(inputSize, 2) - criterion:forward(rnn:forward(input), yBatches[i]) +for i = 1, nbatch do + criterion:forward(rnn:forward(input), target) rnn:zeroGradParameters() - rnn:backward(input, criterion:backward(rnn.output, yBatches[i])) + rnn:backward(input, criterion:backward(rnn.output, target)) rnn:updateParameters(0.01) end -if cpu ~= true then cutorch.synchronize() end +cutorch.synchronize() print("Forward + Backward:") -print("--- " .. nSamples .. " samples in " .. a:time().real .. " seconds (" .. nSamples / a:time().real .. " samples/s) ---") +local speed = nSamples / a:time().real +print("--- " .. nSamples .. " samples in " .. a:time().real .. " seconds (" .. speed .. " samples/s, " .. 1000000/speed .. " microsec/samples) ---")