
Initial support for Long Short Term Memory (LSTM) in DLL

I'm really happy to announce that I just merged support for Long Short Term Memory (LSTM) cells into my Deep Learning Library (DLL) machine learning framework. Two weeks ago, I had already merged support for Recurrent Neural Networks (RNN).

It's nothing fancy yet, but forward propagation of LSTM and basic Backpropagation Through Time (BPTT) are now supported. Implementing the forward pass was not really complicated, but the backward pass is much more complicated for an LSTM than for an RNN. It took me quite a long time to figure out all the gradient formulas, and the documentation on that is quite scarce.
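For reference, here are the textbook LSTM equations (without peephole connections) that such a forward and backward pass is usually built on. This is a sketch of the standard formulation, not necessarily the exact variant implemented in DLL. $\sigma$ is the logistic sigmoid, $\odot$ the element-wise product, $W_k$, $U_k$ and $b_k$ the input weights, recurrent weights and bias of gate $k \in \{i, f, o, g\}$, and $a^k_t$ the pre-activation of gate $k$ at time step $t$. Forward pass:

$$
\begin{aligned}
i_t &= \sigma(a^i_t), & a^i_t &= W_i x_t + U_i h_{t-1} + b_i \\
f_t &= \sigma(a^f_t), & a^f_t &= W_f x_t + U_f h_{t-1} + b_f \\
o_t &= \sigma(a^o_t), & a^o_t &= W_o x_t + U_o h_{t-1} + b_o \\
g_t &= \tanh(a^g_t), & a^g_t &= W_g x_t + U_g h_{t-1} + b_g \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t, & h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$

Backward pass at step $t$, where $\delta h_t$ is the gradient reaching $h_t$ (from the loss at step $t$, if any, plus the recurrent contribution from step $t+1$) and the $\delta c_{t+1} \odot f_{t+1}$ term is zero at the last step:

$$
\begin{aligned}
\delta c_t &= \delta h_t \odot o_t \odot \left(1 - \tanh^2(c_t)\right) + \delta c_{t+1} \odot f_{t+1} \\
\delta a^o_t &= \delta h_t \odot \tanh(c_t) \odot o_t \odot (1 - o_t) \\
\delta a^i_t &= \delta c_t \odot g_t \odot i_t \odot (1 - i_t) \\
\delta a^f_t &= \delta c_t \odot c_{t-1} \odot f_t \odot (1 - f_t) \\
\delta a^g_t &= \delta c_t \odot i_t \odot (1 - g_t^2) \\
\text{recurrent part of } \delta h_{t-1} &= \sum_{k \in \{i,f,o,g\}} U_k^\top \, \delta a^k_t
\end{aligned}
$$

$$
\frac{\partial L}{\partial W_k} = \sum_t \delta a^k_t \, x_t^\top, \qquad
\frac{\partial L}{\partial U_k} = \sum_t \delta a^k_t \, h_{t-1}^\top, \qquad
\frac{\partial L}{\partial b_k} = \sum_t \delta a^k_t
$$

Truncating BPTT after a given number of steps simply stops this backward recursion, and therefore these sums, after that many time steps.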

For now, only the existing classification loss is supported for RNN and LSTM. As I said last time, I still plan to add support for a sequence-to-sequence loss in order to train models that can generate characters. However, I don't know when I'll be able to work on that. Now that I have the code for LSTM, I believe I should be able to implement GRU and NAS cells quite easily.
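To make the difference between the two losses concrete, here is a minimal C++ sketch, assuming the network produces one softmax distribution per time step. The classification loss only scores the last time step against a single label, while a sequence-to-sequence loss scores every time step against its own target (averaged here; summing is also common). The function names are hypothetical and do not come from DLL.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// One probability distribution (softmax output) per time step
using Sequence = std::vector<std::vector<double>>;

// Cross-entropy of a single distribution against an integer label
double cross_entropy(const std::vector<double>& probs, std::size_t label) {
    assert(label < probs.size());
    return -std::log(probs[label]);
}

// Classification loss: only the last time step is scored, against one label
double last_step_loss(const Sequence& outputs, std::size_t label) {
    assert(!outputs.empty());
    return cross_entropy(outputs.back(), label);
}

// Sequence-to-sequence loss: every time step has its own target,
// and the per-step cross-entropies are averaged over the sequence
double sequence_loss(const Sequence& outputs, const std::vector<std::size_t>& labels) {
    assert(!outputs.empty() && outputs.size() == labels.size());
    double total = 0.0;
    for (std::size_t t = 0; t < outputs.size(); ++t) {
        total += cross_entropy(outputs[t], labels[t]);
    }
    return total / static_cast<double>(outputs.size());
}
```

For character generation, each target would typically be the next character of the sequence, so every time step contributes a gradient instead of only the last one.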

For example, here is a simple LSTM used on MNIST for classification:

#include <memory> // std::make_unique

#include "dll/neural/dense_layer.hpp"
#include "dll/neural/lstm_layer.hpp"
#include "dll/neural/recurrent_last_layer.hpp"
#include "dll/network.hpp"
#include "dll/datasets.hpp"

int main(int /*argc*/, char* /*argv*/ []) {
    // Load the dataset
    auto dataset = dll::make_mnist_dataset_nc(dll::batch_size<100>{}, dll::scale_pre<255>{});

    constexpr size_t time_steps      = 28;
    constexpr size_t sequence_length = 28;
    constexpr size_t hidden_units    = 100;

    // Build the network

    using network_t = dll::dyn_network_desc<
        dll::network_layers<
            dll::lstm_layer<time_steps, sequence_length, hidden_units, dll::last_only>,
            dll::recurrent_last_layer<time_steps, hidden_units>,
            dll::dense_layer<hidden_units, 10, dll::softmax>
        >
        , dll::updater<dll::updater_type::ADAM>      // Adam
        , dll::batch_size<100>                       // The mini-batch size
    >::network_t;

    auto net = std::make_unique<network_t>();

    // Display the network and dataset
    net->display();
    dataset.display();

    // Train the network for 50 epochs
    net->fine_tune(dataset.train(), 50);

    // Test the network on test set
    net->evaluate(dataset.test());

    return 0;
}

The network is quite similar to the one used previously with an RNN: just replace rnn with lstm and that's it. It starts with an LSTM layer, followed by a layer extracting the last time step, and finally a dense layer with a softmax function. The network is trained with Adam for 50 epochs. You can change the activation function, the initializers for the weights and the biases, and the number of steps for BPTT truncation.
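To illustrate what these three layers compute at inference time, here is a minimal, dependency-free C++ sketch (not DLL's implementation). Each 28x28 image is read as a sequence of 28 time steps of 28 pixels, an LSTM cell updates its hidden and cell state at every step, only the last hidden state is kept (the role of `recurrent_last_layer`), and a dense layer with softmax produces the class probabilities. The struct, the function names and the gate layout are assumptions made for this example, and biases are omitted for brevity.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;
using Mat = std::vector<Vec>; // row-major: Mat[row][col]

// y = A * x
Vec matvec(const Mat& A, const Vec& x) {
    Vec y(A.size(), 0.0);
    for (std::size_t r = 0; r < A.size(); ++r) {
        assert(A[r].size() == x.size());
        for (std::size_t c = 0; c < x.size(); ++c) y[r] += A[r][c] * x[c];
    }
    return y;
}

double sigmoid(double a) { return 1.0 / (1.0 + std::exp(-a)); }

// Input (W) and recurrent (U) weights of the input, forget, output and candidate gates
struct LstmWeights {
    Mat Wi, Wf, Wo, Wg; // hidden x input
    Mat Ui, Uf, Uo, Ug; // hidden x hidden
};

// One LSTM time step: updates h and c in place
void lstm_step(const LstmWeights& w, const Vec& x, Vec& h, Vec& c) {
    // Recurrent terms are computed from the previous h before it is overwritten
    Vec ai = matvec(w.Wi, x), af = matvec(w.Wf, x), ao = matvec(w.Wo, x), ag = matvec(w.Wg, x);
    Vec ri = matvec(w.Ui, h), rf = matvec(w.Uf, h), ro = matvec(w.Uo, h), rg = matvec(w.Ug, h);
    for (std::size_t j = 0; j < h.size(); ++j) {
        double i = sigmoid(ai[j] + ri[j]);
        double f = sigmoid(af[j] + rf[j]);
        double o = sigmoid(ao[j] + ro[j]);
        double g = std::tanh(ag[j] + rg[j]);
        c[j] = f * c[j] + i * g;
        h[j] = o * std::tanh(c[j]);
    }
}

// Runs the LSTM over all rows of the image and keeps only the last hidden state
Vec lstm_last(const LstmWeights& w, const Mat& image, std::size_t hidden) {
    Vec h(hidden, 0.0), c(hidden, 0.0);
    for (const Vec& row : image) lstm_step(w, row, h, c); // one row = one time step
    return h;
}

// Dense layer (no bias) followed by a numerically stable softmax
Vec dense_softmax(const Mat& W, const Vec& h) {
    Vec z = matvec(W, h);
    double m = *std::max_element(z.begin(), z.end());
    double sum = 0.0;
    for (double& v : z) { v = std::exp(v - m); sum += v; }
    for (double& v : z) v /= sum;
    return z;
}
```

With the sizes of the example (28 inputs, 100 hidden units, 10 classes), `Wi` to `Wg` would be 100x28 matrices, `Ui` to `Ug` 100x100 matrices, and the dense matrix 10x100.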

Here is the result I got on my last run:

------------------------------------------------------------
| Index | Layer                | Parameters | Output Shape |
------------------------------------------------------------
| 0     | LSTM (TANH) (dyn)    |      51200 | [Bx28x100]   |
| 1     | RNN(last)            |          0 | [Bx100]      |
| 2     | Dense(SOFTMAX) (dyn) |       1000 | [Bx10]       |
------------------------------------------------------------
              Total Parameters:      52200

--------------------------------------------
| mnist | Size  | Batches | Augmented Size |
--------------------------------------------
| train | 60000 | 600     | 60000          |
| test  | 10000 | 100     | 10000          |
--------------------------------------------

Network with 3 layers
    LSTM(dyn): 28x28 -> TANH -> 28x100
    RNN(last): 28x100 -> 100
    Dense(dyn): 100 -> SOFTMAX -> 10
Total parameters: 52200
Dataset
Training: In-Memory Data Generator
              Size: 60000
           Batches: 600
Testing: In-Memory Data Generator
              Size: 10000
           Batches: 100

Train the network with "Stochastic Gradient Descent"
    Updater: ADAM
       Loss: CATEGORICAL_CROSS_ENTROPY
 Early Stop: Goal(error)

With parameters:
          epochs=50
      batch_size=100
   learning_rate=0.001
           beta1=0.9
           beta2=0.999

epoch   0/50 batch  600/ 600 - error: 0.07943 loss: 0.28504 time 20910ms
epoch   1/50 batch  600/ 600 - error: 0.06683 loss: 0.24021 time 20889ms
epoch   2/50 batch  600/ 600 - error: 0.04828 loss: 0.18233 time 21061ms
epoch   3/50 batch  600/ 600 - error: 0.04407 loss: 0.16665 time 20839ms
epoch   4/50 batch  600/ 600 - error: 0.03515 loss: 0.13290 time 22108ms
epoch   5/50 batch  600/ 600 - error: 0.03207 loss: 0.12019 time 21393ms
epoch   6/50 batch  600/ 600 - error: 0.02973 loss: 0.11239 time 28199ms
epoch   7/50 batch  600/ 600 - error: 0.02653 loss: 0.10455 time 37039ms
epoch   8/50 batch  600/ 600 - error: 0.02482 loss: 0.09657 time 23127ms
epoch   9/50 batch  600/ 600 - error: 0.02177 loss: 0.08422 time 41766ms
epoch  10/50 batch  600/ 600 - error: 0.02453 loss: 0.09382 time 29765ms
epoch  11/50 batch  600/ 600 - error: 0.02575 loss: 0.09796 time 21449ms
epoch  12/50 batch  600/ 600 - error: 0.02107 loss: 0.07833 time 42056ms
epoch  13/50 batch  600/ 600 - error: 0.01877 loss: 0.07171 time 24673ms
epoch  14/50 batch  600/ 600 - error: 0.02095 loss: 0.08481 time 20878ms
epoch  15/50 batch  600/ 600 - error: 0.02040 loss: 0.07578 time 41515ms
epoch  16/50 batch  600/ 600 - error: 0.01580 loss: 0.06083 time 25705ms
epoch  17/50 batch  600/ 600 - error: 0.01945 loss: 0.07046 time 20903ms
epoch  18/50 batch  600/ 600 - error: 0.01728 loss: 0.06683 time 41828ms
epoch  19/50 batch  600/ 600 - error: 0.01577 loss: 0.05947 time 27810ms
epoch  20/50 batch  600/ 600 - error: 0.01528 loss: 0.05883 time 21477ms
epoch  21/50 batch  600/ 600 - error: 0.01345 loss: 0.05127 time 44718ms
epoch  22/50 batch  600/ 600 - error: 0.01410 loss: 0.05357 time 25174ms
epoch  23/50 batch  600/ 600 - error: 0.01268 loss: 0.04765 time 23827ms
epoch  24/50 batch  600/ 600 - error: 0.01342 loss: 0.05004 time 47232ms
epoch  25/50 batch  600/ 600 - error: 0.01730 loss: 0.06872 time 22532ms
epoch  26/50 batch  600/ 600 - error: 0.01337 loss: 0.05016 time 30114ms
epoch  27/50 batch  600/ 600 - error: 0.01842 loss: 0.07049 time 40136ms
epoch  28/50 batch  600/ 600 - error: 0.01262 loss: 0.04639 time 21793ms
epoch  29/50 batch  600/ 600 - error: 0.01403 loss: 0.05292 time 34096ms
epoch  30/50 batch  600/ 600 - error: 0.01185 loss: 0.04456 time 35420ms
epoch  31/50 batch  600/ 600 - error: 0.01098 loss: 0.04180 time 20909ms
epoch  32/50 batch  600/ 600 - error: 0.01337 loss: 0.04687 time 30113ms
epoch  33/50 batch  600/ 600 - error: 0.01415 loss: 0.05292 time 37393ms
epoch  34/50 batch  600/ 600 - error: 0.00982 loss: 0.03615 time 20962ms
epoch  35/50 batch  600/ 600 - error: 0.01178 loss: 0.04830 time 29305ms
epoch  36/50 batch  600/ 600 - error: 0.00882 loss: 0.03408 time 38293ms
epoch  37/50 batch  600/ 600 - error: 0.01148 loss: 0.04341 time 20841ms
epoch  38/50 batch  600/ 600 - error: 0.00960 loss: 0.03701 time 29204ms
epoch  39/50 batch  600/ 600 - error: 0.00850 loss: 0.03094 time 39802ms
epoch  40/50 batch  600/ 600 - error: 0.01473 loss: 0.05136 time 20831ms
epoch  41/50 batch  600/ 600 - error: 0.01007 loss: 0.03579 time 29856ms
epoch  42/50 batch  600/ 600 - error: 0.00943 loss: 0.03370 time 38200ms
epoch  43/50 batch  600/ 600 - error: 0.01205 loss: 0.04409 time 21162ms
epoch  44/50 batch  600/ 600 - error: 0.00980 loss: 0.03674 time 32279ms
epoch  45/50 batch  600/ 600 - error: 0.01068 loss: 0.04133 time 38448ms
epoch  46/50 batch  600/ 600 - error: 0.00913 loss: 0.03478 time 20797ms
epoch  47/50 batch  600/ 600 - error: 0.00985 loss: 0.03759 time 28885ms
epoch  48/50 batch  600/ 600 - error: 0.00912 loss: 0.03295 time 41120ms
epoch  49/50 batch  600/ 600 - error: 0.00930 loss: 0.03438 time 21282ms
Restore the best (error) weights from epoch 39
Training took 1460s

Evaluation Results
   error: 0.02440
    loss: 0.11315
evaluation took 1000ms
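The parameter counts in the table above can be checked by hand. They are consistent with an LSTM that has, for each of its four gates, a 100x28 input weight matrix and a 100x100 recurrent weight matrix, and with a 100x10 dense layer; bias terms do not appear to be included in these counts. A small C++ check of that arithmetic, with hypothetical helper names:

```cpp
#include <cstddef>

// Four gates (input, forget, output, candidate), each with input and recurrent weights
constexpr std::size_t lstm_weights(std::size_t inputs, std::size_t hidden) {
    return 4 * (hidden * inputs + hidden * hidden);
}

// Fully-connected weights only (no bias)
constexpr std::size_t dense_weights(std::size_t inputs, std::size_t outputs) {
    return inputs * outputs;
}

// 4 * (100 * 28 + 100 * 100) = 4 * 12800 = 51200
static_assert(lstm_weights(28, 100) == 51200, "LSTM row of the table");
static_assert(dense_weights(100, 10) == 1000, "Dense row of the table");
static_assert(lstm_weights(28, 100) + dense_weights(100, 10) == 52200, "Total Parameters");
```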

Again, nothing fancy yet: the network reaches 2.44% error on the test set, and this example has not been optimized for either performance or accuracy.

I also made a few changes to the RNN layer: I added support for biases and improved the code for both performance and readability.
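For reference, a vanilla RNN step with a bias computes $h_t = \tanh(W x_t + U h_{t-1} + b)$. Without the bias, a unit whose weighted inputs sum to zero always outputs $\tanh(0) = 0$; the bias lets each unit shift that operating point. Here is a minimal C++ sketch of that standard formulation, assuming a tanh activation; the function name is hypothetical and this is not DLL's code.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;
using Mat = std::vector<Vec>; // row-major: Mat[row][col]

// One RNN time step: h_t = tanh(W * x_t + U * h_prev + b)
Vec rnn_step(const Mat& W, const Mat& U, const Vec& b, const Vec& x, const Vec& h_prev) {
    assert(W.size() == b.size() && U.size() == b.size() && h_prev.size() == b.size());
    Vec h(b.size());
    for (std::size_t j = 0; j < b.size(); ++j) {
        assert(W[j].size() == x.size() && U[j].size() == h_prev.size());
        double a = b[j]; // the bias shifts the pre-activation of hidden unit j
        for (std::size_t k = 0; k < x.size(); ++k) a += W[j][k] * x[k];
        for (std::size_t k = 0; k < h_prev.size(); ++k) a += U[j][k] * h_prev[k];
        h[j] = std::tanh(a);
    }
    return h;
}
```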

All of this is now in the master branch of the DLL project if you want to check it out. You can also look at the full example online: mnist_lstm.cpp

You can access the project on GitHub.

Currently, I'm working on GPU performance again. The performance of some operations is still not as good as I would like, especially the complex operations used in Adam and Nadam. At the moment, they result in many calls to GPU BLAS libraries, and I want to try to extract some more optimized patterns. Once that's done, I'll post more about it on the blog.
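To give an idea of why Adam involves so many operations, here is a minimal C++ sketch of one standard Adam update over a parameter vector, using the hyperparameters from the training log above (learning rate 0.001, beta1 0.9, beta2 0.999); the epsilon of 1e-8 is an assumption, and the `Adam` struct is hypothetical, not DLL's code. Each update chains two moving averages, two bias corrections, a square root and a division. If each of these is issued as a separate GPU call, the intermediate vectors go through device memory between calls; in this sketch they are all done in a single loop over the parameters.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// State of the Adam optimizer for one parameter vector
struct Adam {
    double lr = 0.001, beta1 = 0.9, beta2 = 0.999, eps = 1e-8;
    std::vector<double> m, v; // first and second moment estimates
    std::size_t t = 0;        // number of updates done so far

    void step(std::vector<double>& w, const std::vector<double>& grad) {
        assert(w.size() == grad.size());
        if (m.empty()) { m.assign(w.size(), 0.0); v.assign(w.size(), 0.0); }
        ++t;
        // Bias corrections for the zero-initialized moment estimates
        const double c1 = 1.0 - std::pow(beta1, static_cast<double>(t));
        const double c2 = 1.0 - std::pow(beta2, static_cast<double>(t));
        // All element-wise operations done in a single pass over the parameters
        for (std::size_t k = 0; k < w.size(); ++k) {
            m[k] = beta1 * m[k] + (1.0 - beta1) * grad[k];
            v[k] = beta2 * v[k] + (1.0 - beta2) * grad[k] * grad[k];
            const double m_hat = m[k] / c1;
            const double v_hat = v[k] / c2;
            w[k] -= lr * m_hat / (std::sqrt(v_hat) + eps);
        }
    }
};
```

On the first update, the bias-corrected estimates equal the gradient and its square, so each weight moves by roughly the learning rate in the direction opposite to its gradient, whatever the gradient's magnitude.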
