Recurrent Neural Networks (RNNs) are a class of neural networks designed to handle sequential data. Unlike standard feedforward networks, RNNs maintain an internal state (memory) that is updated as each element of a sequence is processed, allowing information to persist across time steps.
Sequence Architectures
RNNs are flexible and can be adapted to various input-output mapping structures depending on the task:
- One-to-Many: A single input produces a sequence of outputs.
  - Example: Image Captioning (Input: image → Output: sequence of words).
- Many-to-One: A sequence of inputs produces a single output.
  - Example: Action Prediction or Sentiment Analysis (Input: sequence of video frames/words → Output: label).
- Many-to-Many: A sequence of inputs produces a sequence of outputs.
  - Example: Video Captioning (Input: sequence of frames → Output: sequence of words) or Video Classification (frame-by-frame labeling).
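To make these mappings concrete, here is a sketch of the tensor shapes involved in each pattern (the sequence length T and feature sizes are illustrative placeholders):

```python
import numpy as np

T, D_in, D_out = 5, 8, 3  # illustrative sequence length and feature sizes

# One-to-Many: a single input vector -> a sequence of outputs
one_input = np.zeros(D_in)           # e.g. an image embedding
many_outputs = np.zeros((T, D_out))  # e.g. a sequence of word scores

# Many-to-One: a sequence of inputs -> a single output
many_inputs = np.zeros((T, D_in))    # e.g. T video frames or word vectors
one_output = np.zeros(D_out)         # e.g. one label-score vector

# Many-to-Many: a sequence of inputs -> a sequence of outputs
seq_in = np.zeros((T, D_in))
seq_out = np.zeros((T, D_out))       # e.g. one label at every time step

print(one_input.shape, many_outputs.shape, seq_in.shape, seq_out.shape)
```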
The Recurrence Mechanism
The “key idea” behind an RNN is the use of a recurrence formula applied at every time step t. This allows the network to process a sequence of vectors x_t by maintaining a hidden state h_t.
A. Hidden State Update
The hidden state is updated by combining the previous state with the current input:

h_t = f_W(h_{t-1}, x_t)

Parameters:
- h_t: New state (current hidden state).
- h_{t-1}: Old state (hidden state from the previous time step).
- x_t: Input vector at the current time step.
- f_W: A function (often a non-linear activation like tanh or ReLU) with trainable parameters W.

Note: the same function f_W and the same set of parameters W are used at every time step.
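As a minimal sketch of this update, assuming the common vanilla-RNN parameterization f_W(h, x) = tanh(W_hh·h + W_xh·x + b) (the matrix names W_hh and W_xh are illustrative, not prescribed by the text above):

```python
import numpy as np

def rnn_step(h_prev, x, W_hh, W_xh, b):
    """One application of the recurrence h_t = f_W(h_{t-1}, x_t)."""
    return np.tanh(W_hh @ h_prev + W_xh @ x + b)

rng = np.random.default_rng(0)
H, D = 4, 3                    # hidden size and input size (illustrative)
W_hh = rng.standard_normal((H, H))
W_xh = rng.standard_normal((H, D))
b = np.zeros(H)

h = np.zeros(H)                # initial hidden state
for x in rng.standard_normal((6, D)):
    # the same f_W and the same parameters are reused at every step
    h = rnn_step(h, x, W_hh, W_xh, b)
print(h.shape)  # (4,)
```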
B. Output Generation
After the hidden state is updated, the network can produce an output at that specific time step:

y_t = f_{W_hy}(h_t)

Parameters:
- y_t: Output at time t.
- h_t: New state (the updated hidden state).
- f_{W_hy}: A separate function with its own trainable parameters W_hy, used to map the hidden state to the output space.
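Continuing the sketch, the output function is often just a linear readout of the hidden state, y_t = W_hy·h_t (the name W_hy and the linear form are illustrative assumptions):

```python
import numpy as np

def rnn_output(h, W_hy):
    """Map the updated hidden state h_t to an output y_t = W_hy @ h_t."""
    return W_hy @ h

rng = np.random.default_rng(1)
H, D_out = 4, 2
W_hy = rng.standard_normal((D_out, H))  # separate parameters from f_W

h_t = rng.standard_normal(H)
y_t = rnn_output(h_t, W_hy)
print(y_t.shape)  # (2,)
```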
Backpropagation Through Time
Optimizing an RNN requires a specialized form of backpropagation known as Backpropagation Through Time (BPTT).
In its standard form, the model performs a complete forward pass through the entire sequence to calculate a global loss. During the subsequent backward pass, gradients are propagated from the final loss all the way back to the first time step.
While this captures long-range dependencies accurately, it is often computationally prohibitive for long sequences due to the massive memory requirements for storing intermediate states.
To mitigate these resource constraints, researchers often employ Truncated BPTT. This technique involves partitioning the sequence into smaller, manageable chunks.
The model performs a forward and backward pass on a specific chunk to update its weights before moving to the next. Crucially, while the hidden state is carried forward into the next chunk to maintain continuity, the gradient flow is “truncated” at the chunk boundary.
This approximation significantly reduces memory overhead and allows for the training of models on extended temporal data without sacrificing the benefits of sequential learning.
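The chunking pattern can be sketched end to end on a tiny vanilla RNN. This is a toy illustration, not a production training loop: the loss (sum of the final hidden units in each chunk), learning rate, and all weight names are arbitrary. The essential points are that the hidden state crosses the chunk boundary while the gradient does not.

```python
import numpy as np

def truncated_bptt(xs, k, H, rng):
    """Sketch of Truncated BPTT for h_t = tanh(W_hh h_{t-1} + W_xh x_t).

    The sequence xs is split into chunks of length k. The hidden state is
    carried across chunk boundaries, but each backward pass stops at the
    boundary: the gradient flowing into the carried-in state is dropped.
    """
    D = xs.shape[1]
    W_hh = rng.standard_normal((H, H)) * 0.1
    W_xh = rng.standard_normal((H, D)) * 0.1
    h = np.zeros(H)
    lr = 0.01
    for start in range(0, len(xs), k):
        chunk = xs[start:start + k]
        # --- forward over one chunk, caching states for the backward pass ---
        hs = [h]
        for x in chunk:
            hs.append(np.tanh(W_hh @ hs[-1] + W_xh @ x))
        # --- backward over this chunk only ---
        dW_hh = np.zeros_like(W_hh)
        dW_xh = np.zeros_like(W_xh)
        dh = np.ones(H)                    # d(toy loss)/d(last hidden state)
        for t in range(len(chunk) - 1, -1, -1):
            dz = dh * (1 - hs[t + 1] ** 2)  # back through tanh
            dW_hh += np.outer(dz, hs[t])
            dW_xh += np.outer(dz, chunk[t])
            dh = W_hh.T @ dz               # flows back within the chunk...
        # ...but dh is discarded here: no gradient crosses the boundary.
        W_hh -= lr * dW_hh
        W_xh -= lr * dW_xh
        h = hs[-1]                         # the state itself is carried forward
    return W_hh, W_xh

rng = np.random.default_rng(0)
W_hh_fin, W_xh_fin = truncated_bptt(rng.standard_normal((20, 3)), k=5, H=4, rng=rng)
print(W_hh_fin.shape, W_xh_fin.shape)  # (4, 4) (4, 3)
```

In autograd frameworks the same truncation is usually achieved by detaching the carried-over hidden state from the computation graph at each chunk boundary.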
RNN Tradeoffs
RNNs can process input sequences of arbitrary length, and their model size remains constant regardless of input length since the same weights are shared across all time steps—ensuring temporal symmetry in how inputs are processed. In principle, they can leverage information from arbitrarily distant time steps.
However, in practice, recurrent computation tends to be slow due to its sequential nature, and capturing long-range dependencies remains challenging because relevant information from many steps back often becomes inaccessible or diluted over time.
Long Short-Term Memory
Long Short-Term Memory (LSTM) networks alleviate vanishing and exploding gradients when training on long sequences by using a gated architecture that regulates the flow of information.
LSTM Mathematics
The operations of an LSTM cell at time step t are defined by three primary equations:
1. The Gate Vectors
LSTMs compute four internal vectors (gates and candidates) simultaneously by concatenating the previous hidden state h_{t-1} and the current input x_t:

i = sigmoid(W_i [h_{t-1}; x_t])
f = sigmoid(W_f [h_{t-1}; x_t])
o = sigmoid(W_o [h_{t-1}; x_t])
g = tanh(W_g [h_{t-1}; x_t])

- i (Input Gate): Decides which new information to store in the cell state.
- f (Forget Gate): Decides which information to discard from the previous state.
- o (Output Gate): Decides which part of the cell state to output.
- g (Cell Candidate): Creates a vector of new candidate values (using tanh) to be added to the state.
2. The Cell State (c_t)
This is the “long-term memory” of the network. It is updated via a linear combination of the old state and the new candidates:

c_t = f ⊙ c_{t-1} + i ⊙ g

The use of the Hadamard product (⊙, element-wise multiplication) allows the forget gate to “zero out” specific memories while the input gate adds new ones. This additive update is the key to preventing vanishing gradients.
3. The Hidden State (h_t)
The hidden state is the “working memory” passed to the next cell and the higher layers. It is a filtered version of the cell state:

h_t = o ⊙ tanh(c_t)
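Putting the three pieces together, here is a minimal single-step LSTM cell. It is a sketch under the common convention of stacking the four weight matrices into one matrix W applied to the concatenation [h_{t-1}; x_t]; the names W, b and the block ordering (i, f, o, g) are illustrative:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(h_prev, c_prev, x, W, b):
    """One LSTM step: gates from [h_{t-1}; x_t], then c_t and h_t."""
    H = h_prev.shape[0]
    z = W @ np.concatenate([h_prev, x]) + b  # all four vectors at once
    i = sigmoid(z[0 * H:1 * H])              # input gate
    f = sigmoid(z[1 * H:2 * H])              # forget gate
    o = sigmoid(z[2 * H:3 * H])              # output gate
    g = np.tanh(z[3 * H:4 * H])              # cell candidate
    c = f * c_prev + i * g                   # additive cell-state update
    h = o * np.tanh(c)                       # filtered "working memory"
    return h, c

rng = np.random.default_rng(0)
H, D = 4, 3
W = rng.standard_normal((4 * H, H + D))
b = np.zeros(4 * H)
h, c = lstm_step(np.zeros(H), np.zeros(H), rng.standard_normal(D), W, b)
print(h.shape, c.shape)  # (4,) (4,)
```

The `f * c_prev + i * g` line is the additive update discussed above; because gradients flow through it via element-wise multiplication with the forget gate rather than repeated matrix multiplication, they vanish far less readily.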