Initial support for Recurrent Neural Network (RNN) in DLL

I'm happy to announce that I just merged support for Recurrent Neural Networks (RNNs) into my Deep Learning Library (DLL) machine learning framework.

It's nothing fancy yet, but forward propagation of RNNs and basic Backpropagation Through Time (BPTT) are now supported. For now, only the existing classification loss is supported for RNNs. I plan to add support for a sequence-to-sequence loss in order to train models that can generate characters, but I don't know when I'll be able to work on that. I also plan to add support for other types of cells, such as LSTM and GRU (and maybe NAS), in the future.
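Concretely, the forward propagation of such a simple RNN boils down to the recurrence h_t = tanh(W * x_t + U * h_{t-1}), applied once per time step; BPTT then walks the same loop backwards over a number of steps to compute the gradients of W and U. Here is a minimal, framework-independent sketch of the forward recurrence (illustrative code only, not the DLL implementation):

#include <cmath>
#include <cstddef>
#include <vector>

// Sketch of a simple RNN forward pass (not DLL code):
// for each time step t, h_t = tanh(W * x_t + U * h_{t-1}).
std::vector<double> rnn_forward(const std::vector<std::vector<double>>& x,  // time_steps x sequence_length
                                const std::vector<std::vector<double>>& W,  // hidden_units x sequence_length
                                const std::vector<std::vector<double>>& U)  // hidden_units x hidden_units
{
    const std::size_t hidden_units = W.size();

    std::vector<double> h(hidden_units, 0.0); // h_0 = 0

    for (const auto& x_t : x) {               // one iteration per time step
        std::vector<double> h_next(hidden_units, 0.0);

        for (std::size_t i = 0; i < hidden_units; ++i) {
            double a = 0.0;

            for (std::size_t j = 0; j < x_t.size(); ++j)   // W * x_t
                a += W[i][j] * x_t[j];

            for (std::size_t j = 0; j < hidden_units; ++j) // U * h_{t-1}
                a += U[i][j] * h[j];

            h_next[i] = std::tanh(a);
        }

        h = h_next;
    }

    return h; // last hidden state, used here for classification
}

int main() {
    const std::size_t time_steps = 28, sequence_length = 28, hidden_units = 100;

    // Dummy data, just to show the shapes involved (one MNIST image = 28 rows of 28 pixels)
    std::vector<std::vector<double>> x(time_steps, std::vector<double>(sequence_length, 0.0));
    std::vector<std::vector<double>> W(hidden_units, std::vector<double>(sequence_length, 0.01));
    std::vector<std::vector<double>> U(hidden_units, std::vector<double>(hidden_units, 0.01));

    auto h = rnn_forward(x, W, U);
    return h.size() == hidden_units ? 0 : 1;
}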

For example, here is a simple RNN used on MNIST:

#include "dll/neural/dense_layer.hpp"
#include "dll/neural/recurrent_layer.hpp"
#include "dll/neural/recurrent_last_layer.hpp"
#include "dll/network.hpp"
#include "dll/datasets.hpp"

int main(int /*argc*/, char* /*argv*/ []) {
    // Load the dataset
    auto dataset = dll::make_mnist_dataset_nc(dll::batch_size<100>{}, dll::scale_pre<255>{});

    constexpr size_t time_steps      = 28;
    constexpr size_t sequence_length = 28;
    constexpr size_t hidden_units    = 100;

    // Build the network

    using network_t = dll::dyn_network_desc<
        dll::network_layers<
            dll::recurrent_layer<time_steps, sequence_length, hidden_units, dll::last_only>,
            dll::recurrent_last_layer<time_steps, hidden_units>,
            dll::dense_layer<hidden_units, 10, dll::softmax>
        >
        , dll::updater<dll::updater_type::ADAM>      // Adam
        , dll::batch_size<100>                       // The mini-batch size
    >::network_t;

    auto net = std::make_unique<network_t>();

    // Display the network
    net->display();

    // Train the network for 50 epochs
    net->fine_tune(dataset.train(), 50);

    // Test the network on the test set
    net->evaluate(dataset.test());

    return 0;
}

The network starts with a recurrent layer, followed by a layer that extracts only the last time step of the recurrent output, and finally a dense layer with a softmax activation. The recurrent layer supports changing the activation function, the initializers for the two weight matrices of the RNN, and the number of steps used for BPTT truncation.
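As a rough illustration of how such options would be passed as extra template parameters, a more customized recurrent layer could look like the sketch below. Note that the option names here (in particular dll::truncate, and the same would hold for the initializer options) are assumptions on my part rather than confirmed DLL API, so check the recurrent layer headers before using them:

// Hypothetical sketch: the option names below are assumptions, not confirmed DLL API.
using custom_rnn_t = dll::recurrent_layer<
    time_steps, sequence_length, hidden_units,
    dll::last_only,                        // keep only what the next layer needs
    dll::activation<dll::function::RELU>,  // assumed option: ReLU instead of the default TANH
    dll::truncate<20>                      // assumed option: truncate BPTT to 20 steps
>;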

Here is a possible result:

Network with 3 layers
    RNN(dyn): 28x28 -> TANH -> 28x100
    RNN(last): 28x100 -> 100
    Dense(dyn): 100 -> SOFTMAX -> 10
Total parameters: 13800
Train the network with "Stochastic Gradient Descent"
    Updater: ADAM
       Loss: CATEGORICAL_CROSS_ENTROPY
 Early Stop: Goal(error)

With parameters:
          epochs=50
      batch_size=100
   learning_rate=0.001
           beta1=0.9
           beta2=0.999

Epoch   0/50 - Classification error: 0.11635 Loss: 0.39999 Time 4717ms
Epoch   1/50 - Classification error: 0.11303 Loss: 0.36994 Time 4702ms
Epoch   2/50 - Classification error: 0.06732 Loss: 0.23469 Time 4702ms
Epoch   3/50 - Classification error: 0.04865 Loss: 0.17091 Time 4696ms
Epoch   4/50 - Classification error: 0.05957 Loss: 0.20437 Time 4706ms
Epoch   5/50 - Classification error: 0.05022 Loss: 0.16888 Time 4696ms
Epoch   6/50 - Classification error: 0.03912 Loss: 0.13743 Time 4698ms
Epoch   7/50 - Classification error: 0.04097 Loss: 0.14509 Time 4706ms
Epoch   8/50 - Classification error: 0.03938 Loss: 0.13397 Time 4694ms
Epoch   9/50 - Classification error: 0.03525 Loss: 0.12284 Time 4706ms
Epoch  10/50 - Classification error: 0.03927 Loss: 0.13770 Time 4694ms
Epoch  11/50 - Classification error: 0.03315 Loss: 0.11315 Time 4711ms
Epoch  12/50 - Classification error: 0.05037 Loss: 0.17123 Time 4711ms
Epoch  13/50 - Classification error: 0.02927 Loss: 0.10042 Time 4780ms
Epoch  14/50 - Classification error: 0.03322 Loss: 0.11027 Time 4746ms
Epoch  15/50 - Classification error: 0.03397 Loss: 0.11585 Time 4684ms
Epoch  16/50 - Classification error: 0.02938 Loss: 0.09984 Time 4708ms
Epoch  17/50 - Classification error: 0.03262 Loss: 0.11152 Time 4690ms
Epoch  18/50 - Classification error: 0.02872 Loss: 0.09753 Time 4672ms
Epoch  19/50 - Classification error: 0.02548 Loss: 0.08605 Time 4691ms
Epoch  20/50 - Classification error: 0.02245 Loss: 0.07797 Time 4693ms
Epoch  21/50 - Classification error: 0.02705 Loss: 0.08984 Time 4684ms
Epoch  22/50 - Classification error: 0.02422 Loss: 0.08164 Time 4688ms
Epoch  23/50 - Classification error: 0.02645 Loss: 0.08804 Time 4690ms
Epoch  24/50 - Classification error: 0.02927 Loss: 0.09739 Time 4715ms
Epoch  25/50 - Classification error: 0.02578 Loss: 0.08669 Time 4702ms
Epoch  26/50 - Classification error: 0.02785 Loss: 0.09368 Time 4700ms
Epoch  27/50 - Classification error: 0.02472 Loss: 0.08237 Time 4695ms
Epoch  28/50 - Classification error: 0.02125 Loss: 0.07324 Time 4690ms
Epoch  29/50 - Classification error: 0.01977 Loss: 0.06635 Time 4688ms
Epoch  30/50 - Classification error: 0.03635 Loss: 0.12140 Time 4689ms
Epoch  31/50 - Classification error: 0.02862 Loss: 0.09704 Time 4698ms
Epoch  32/50 - Classification error: 0.02463 Loss: 0.08158 Time 4686ms
Epoch  33/50 - Classification error: 0.02565 Loss: 0.08771 Time 4697ms
Epoch  34/50 - Classification error: 0.02278 Loss: 0.07634 Time 4718ms
Epoch  35/50 - Classification error: 0.02105 Loss: 0.07075 Time 4697ms
Epoch  36/50 - Classification error: 0.02770 Loss: 0.09358 Time 4711ms
Epoch  37/50 - Classification error: 0.02627 Loss: 0.08805 Time 4742ms
Epoch  38/50 - Classification error: 0.02282 Loss: 0.07712 Time 4708ms
Epoch  39/50 - Classification error: 0.02305 Loss: 0.07661 Time 4697ms
Epoch  40/50 - Classification error: 0.02243 Loss: 0.07773 Time 4700ms
Epoch  41/50 - Classification error: 0.02467 Loss: 0.08234 Time 4712ms
Epoch  42/50 - Classification error: 0.01808 Loss: 0.06186 Time 4691ms
Epoch  43/50 - Classification error: 0.02388 Loss: 0.07917 Time 4681ms
Epoch  44/50 - Classification error: 0.02162 Loss: 0.07508 Time 4699ms
Epoch  45/50 - Classification error: 0.01877 Loss: 0.06289 Time 4735ms
Epoch  46/50 - Classification error: 0.02263 Loss: 0.07969 Time 4764ms
Epoch  47/50 - Classification error: 0.02100 Loss: 0.07207 Time 4684ms
Epoch  48/50 - Classification error: 0.02425 Loss: 0.08076 Time 4752ms
Epoch  49/50 - Classification error: 0.02328 Loss: 0.07803 Time 4718ms
Restore the best (error) weights from epoch 42
Training took 235s
Evaluation Results
   error: 0.03000
    loss: 0.12260
evaluation took 245ms
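
As a quick sanity check of the "Total parameters: 13800" line above (assuming the count covers only the two recurrent weight matrices and the dense weight matrix, not biases), the arithmetic works out exactly:

    RNN input-to-hidden weights (W):    28 x 100 =  2800
    RNN hidden-to-hidden weights (U):  100 x 100 = 10000
    Dense layer weights:               100 x 10  =  1000
                                       Total     = 13800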

A 3% test error is nothing fancy, but this example was not particularly tuned either.

All this support is now in the master branch of the DLL project if you want to check it out. You can also find the full example online: mnist_rnn.cpp

You can access the project on GitHub.

Related articles

  • Deep Learning Library 1.0 - Fast Neural Network Library
  • Initial support for Long Short Term Memory (LSTM) in DLL
  • DLL: Pretty printing and live output
  • Update on Deep Learning Library (DLL): Dropout, Batch Normalization, Adaptive Learning Rates, ...
  • DLL New Features: Embeddings and Merge layers
  • DLL: Blazing Fast Neural Network Library