What Are RNNs?
Recurrent Neural Networks (RNNs) are a class of artificial neural networks designed to recognize patterns in sequences of data, such as text, time series, and speech.
Unlike traditional feedforward neural networks, RNNs process sequential data one step at a time while maintaining a hidden state that captures information about previous steps. Here's a breakdown:
- Sequential Input Processing:
- In RNNs, the input is not processed all at once. Instead, the network takes one element of the sequence at a time.
- Example: Processing the sentence "I love AI" involves feeding the words one by one to the network.
- Hidden State:
- At each time step, RNNs update their hidden state based on the current input and the previous hidden state.
- The hidden state acts as a "memory" that accumulates information over the sequence (a rough sketch of this loop follows below).
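As a rough, framework-free sketch of this word-by-word loop (the `rnn_step` function and the toy embeddings below are hypothetical placeholders, not the actual update rule, which is defined in the next section):

```python
import numpy as np

def rnn_step(word_vector, hidden):
    # Hypothetical placeholder for the real update rule shown in the next section;
    # here it just mixes the current word into the running "memory".
    return np.tanh(word_vector + hidden)

# Toy 2-dimensional embeddings for "I love AI" (values made up for illustration)
embeddings = {
    "I": np.array([0.1, 0.2]),
    "love": np.array([0.4, 0.3]),
    "AI": np.array([0.9, 0.5]),
}

hidden = np.zeros(2)  # The hidden state starts empty: no memory yet
for word in ["I", "love", "AI"]:
    hidden = rnn_step(embeddings[word], hidden)  # One word per step, memory carried forward
    print(word, "->", hidden)
```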
How RNNs Work
At the heart of an RNN is the hidden state $h_t$, which serves as the network's "memory". For a given input $x_t$, the RNN updates the hidden state at each time step using:

$$h_t = f(W_{xh} x_t + W_{hh} h_{t-1} + b_h)$$

Here:
- $f$ is the activation function, typically $\tanh$ or $\mathrm{ReLU}$.
- $W_{xh}$ and $W_{hh}$ are weight matrices.
- $b_h$ is a bias term.
The output $\hat{y}_t$ is often derived by applying another weight matrix to the hidden state:

$$\hat{y}_t = g(W_{hy} h_t)$$

where $g$ is typically a softmax function for classification tasks.
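The equations above translate almost directly into code. The following is a minimal NumPy sketch of a single forward step; the matrix shapes, the tanh activation, and the softmax output head are illustrative choices:

```python
import numpy as np

def softmax(z):
    z = z - z.max()  # Subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum()

input_size, hidden_size, output_size = 4, 3, 2

# Randomly initialized parameters, matching W_xh, W_hh, W_hy, and b_h above
W_xh = 0.1 * np.random.randn(hidden_size, input_size)
W_hh = 0.1 * np.random.randn(hidden_size, hidden_size)
W_hy = 0.1 * np.random.randn(output_size, hidden_size)
b_h = np.zeros(hidden_size)

def rnn_forward_step(x_t, h_prev):
    h_t = np.tanh(W_xh @ x_t + W_hh @ h_prev + b_h)  # h_t = f(W_xh x_t + W_hh h_{t-1} + b_h)
    y_t = softmax(W_hy @ h_t)                        # y_t = g(W_hy h_t), with g = softmax
    return h_t, y_t

h = np.zeros(hidden_size)
for x_t in np.random.randn(5, input_size):  # A toy sequence of 5 time steps
    h, y_hat = rnn_forward_step(x_t, h)

print("Final hidden state:", h)
print("Last prediction (class probabilities):", y_hat)
```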
Loss Calculation in RNNs
The loss in RNNs is calculated over the entire sequence. For a sequence of predictions $\hat{y}_1, \hat{y}_2, \ldots, \hat{y}_T$ and corresponding ground truth labels $y_1, y_2, \ldots, y_T$, the total loss is:

$$L = \sum_{t=1}^{T} L_t(\hat{y}_t, y_t)$$
Common Loss Functions
- Cross-Entropy Loss (for classification tasks; a PyTorch sketch of summing these per-step losses over a sequence follows this list):

$$L_t = -\sum_{c=1}^{C} y_{t,c} \log \hat{y}_{t,c}$$

Where:
- $C$: Number of classes.
- $y_{t,c}$: One-hot encoded true label for class $c$ at time step $t$.
- $\hat{y}_{t,c}$: Predicted probability for class $c$ at time step $t$.
- Mean Squared Error (MSE) (for regression tasks):

$$L_t = \frac{1}{n} \sum_{i=1}^{n} \left( \hat{y}_{t,i} - y_{t,i} \right)^2$$

where $n$ is the output dimension.
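As a small sketch of how these per-step losses are summed in practice, the snippet below uses PyTorch's `nn.CrossEntropyLoss` with `reduction="none"` so that each time step's loss is visible before summation; the shapes are toy values chosen for illustration:

```python
import torch
import torch.nn as nn

T, C = 5, 3                              # Sequence length and number of classes (toy values)
logits = torch.randn(T, C)               # Unnormalized predictions, one row per time step
targets = torch.randint(0, C, (T,))      # Ground-truth class index at each time step

# nn.CrossEntropyLoss applies log-softmax internally; reduction="none" keeps one L_t per step
criterion = nn.CrossEntropyLoss(reduction="none")
per_step_loss = criterion(logits, targets)   # Shape (T,), one L_t per time step
total_loss = per_step_loss.sum()             # L = sum_t L_t

print("Per-step losses:", per_step_loss)
print("Total sequence loss:", total_loss.item())
```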
Backpropagation Through Time (BPTT)
Training an RNN involves adjusting its weights $W_{xh}$, $W_{hh}$, and $W_{hy}$ to minimize the loss. This is done using backpropagation through time (BPTT), an extension of standard backpropagation designed for sequential models.
Step-by-Step Process of BPTT
- Forward Pass:
- Compute the hidden states $h_t$ for $t = 1, \ldots, T$.
- Calculate the outputs $\hat{y}_t$ and the corresponding loss $L_t$ for each time step.
- Backward Pass:
- Compute gradients of the loss with respect to the outputs at each time step.
- Propagate these gradients backward through the sequence to update the weights.
- Weight Updates:
- Gradients are accumulated over all time steps and used to update the weights with gradient descent or a similar optimization algorithm (a minimal PyTorch training loop is sketched after this list).
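In frameworks such as PyTorch, BPTT happens automatically when `backward()` is called on a loss accumulated over the unrolled sequence. The sketch below illustrates the three phases on a toy model; the architecture, shapes, and hyperparameters are assumptions made for illustration:

```python
import torch
import torch.nn as nn

# Toy setup; all sizes and hyperparameters here are illustrative assumptions
batch, T, input_size, hidden_size, num_classes = 3, 5, 4, 8, 2
rnn = nn.RNN(input_size, hidden_size, batch_first=True)
head = nn.Linear(hidden_size, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(list(rnn.parameters()) + list(head.parameters()), lr=0.01)

x = torch.randn(batch, T, input_size)
targets = torch.randint(0, num_classes, (batch, T))   # A label at every time step

# 1. Forward pass: unroll over the whole sequence and accumulate the loss
hidden_states, _ = rnn(x)                             # (batch, T, hidden_size)
logits = head(hidden_states)                          # (batch, T, num_classes)
loss = criterion(logits.reshape(-1, num_classes), targets.reshape(-1))

# 2. Backward pass: autograd propagates gradients back through every time step (BPTT)
optimizer.zero_grad()
loss.backward()

# 3. Weight update: apply the accumulated gradients
optimizer.step()
print("Loss:", loss.item())
```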
Gradient Computation
For each weight matrix, gradients are computed using the chain rule:
- Gradient w.r.t. Output Weights ($W_{hy}$):

$$\frac{\partial L}{\partial W_{hy}} = \sum_{t=1}^{T} \frac{\partial L_t}{\partial \hat{y}_t} \frac{\partial \hat{y}_t}{\partial W_{hy}}$$

- Gradient w.r.t. Hidden-to-Hidden Weights ($W_{hh}$): the gradients flow back through all previous time steps, making this term computationally expensive:

$$\frac{\partial L}{\partial W_{hh}} = \sum_{t=1}^{T} \sum_{k=1}^{t} \frac{\partial L_t}{\partial h_t} \left( \prod_{j=k+1}^{t} \frac{\partial h_j}{\partial h_{j-1}} \right) \frac{\partial h_k}{\partial W_{hh}}$$

This requires computing the contribution of each previous time step $k$ to the current state $h_t$, leading to gradient accumulation (illustrated in the autograd sketch after this list).
- Gradient w.r.t. Input-to-Hidden Weights ($W_{xh}$): computed analogously, with the same backward flow through the hidden states:

$$\frac{\partial L}{\partial W_{xh}} = \sum_{t=1}^{T} \sum_{k=1}^{t} \frac{\partial L_t}{\partial h_t} \left( \prod_{j=k+1}^{t} \frac{\partial h_j}{\partial h_{j-1}} \right) \frac{\partial h_k}{\partial W_{xh}}$$
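One way to see this accumulation concretely is to unroll a tiny RNN by hand and let autograd apply the chain rule; the sizes and the toy loss below are arbitrary choices for illustration:

```python
import torch

torch.manual_seed(0)
input_size, hidden_size, T = 2, 3, 4

W_xh = torch.randn(hidden_size, input_size, requires_grad=True)
W_hh = torch.randn(hidden_size, hidden_size, requires_grad=True)
b_h = torch.zeros(hidden_size, requires_grad=True)

xs = torch.randn(T, input_size)
h = torch.zeros(hidden_size)

# Unrolled forward pass: the same W_hh is reused at every time step
for t in range(T):
    h = torch.tanh(W_xh @ xs[t] + W_hh @ h + b_h)

loss = h.sum()   # A toy loss on the final hidden state
loss.backward()  # BPTT: the chain rule is applied back through all T steps

# W_hh.grad now holds the summed contributions from every time step,
# mirroring the nested sum/product structure in the formula above
print("dL/dW_hh:\n", W_hh.grad)
```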
Challenges in BPTT
- Vanishing Gradients:
- When the gradients are multiplied repeatedly during backpropagation, they can become very small, effectively halting learning.
- Solution: Use architectures like Long Short-Term Memory (LSTM) or Gated Recurrent Units (GRUs).
- Exploding Gradients:
- In some cases, gradients grow exponentially, causing numerical instability.
- Solution: Apply gradient clipping to cap the gradient norm during training (a one-line example is sketched below).
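Gradient clipping is typically a one-line addition between the backward pass and the optimizer step. A minimal sketch using PyTorch's `clip_grad_norm_` follows; the model, the toy loss, and the `max_norm` value of 1.0 are illustrative choices:

```python
import torch
import torch.nn as nn

rnn = nn.RNN(input_size=4, hidden_size=8, batch_first=True)
optimizer = torch.optim.SGD(rnn.parameters(), lr=0.01)

x = torch.randn(2, 5, 4)          # (batch, time steps, features)
output, _ = rnn(x)
loss = output.pow(2).mean()       # Toy loss, just to produce gradients

optimizer.zero_grad()
loss.backward()
# Rescale all gradients so their combined norm is at most 1.0 before the update
torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=1.0)
optimizer.step()
```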
Key Insights
- RNNs calculate loss for sequential outputs, summing up individual losses for each time step.
- Backpropagation through time (BPTT) propagates gradients through all time steps, making it computationally intensive and sensitive to long-term dependencies.
- Advanced techniques like LSTMs and GRUs help mitigate the vanishing gradient problem, making training more effective for longer sequences.
Understanding loss calculation and BPTT provides a solid foundation for exploring and designing more sophisticated sequential models.
A Simple RNN Implementation Using PyTorch
```python
import torch
import torch.nn as nn


class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleRNN, self).__init__()
        self.hidden_size = hidden_size
        # Input to hidden layer
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        # Hidden to output layer
        self.h2o = nn.Linear(hidden_size, output_size)
        self.activation = nn.Tanh()  # You can use other activations like ReLU

    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)      # Concatenate input and previous hidden state
        hidden = self.activation(self.i2h(combined))  # Update hidden state
        output = self.h2o(hidden)                     # Generate output
        return output, hidden

    def init_hidden(self, batch_size):
        return torch.zeros(batch_size, self.hidden_size)


# Example usage:
input_size = 10       # Size of input features
hidden_size = 20      # Number of hidden units
output_size = 1       # Size of output (e.g., for regression)
sequence_length = 5   # Length of the input sequence
batch_size = 3        # Number of sequences in a batch

# Create the RNN model
rnn = SimpleRNN(input_size, hidden_size, output_size)

# Sample input data (batch_size, sequence_length, input_size)
input_data = torch.randn(batch_size, sequence_length, input_size)

# Initialize hidden state
hidden = rnn.init_hidden(batch_size)

# Iterate through the sequence
for t in range(sequence_length):
    input_t = input_data[:, t, :]           # Get input at time step t
    output, hidden = rnn(input_t, hidden)   # Pass the current input and the previous hidden state

print("Output shape:", output.shape)  # Should be (batch_size, output_size)
print("Hidden shape:", hidden.shape)  # Should be (batch_size, hidden_size)


# Example with a sequence input at once (more typical usage)
class SimpleRNNSequence(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleRNNSequence, self).__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)  # Use the built-in RNN
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # x shape: (batch_size, seq_len, input_size)
        output, _ = self.rnn(x)  # output shape: (batch_size, seq_len, hidden_size)
        # We usually take the output of the last time step for classification/regression
        output = self.fc(output[:, -1, :])  # output shape: (batch_size, output_size)
        return output


rnn_seq = SimpleRNNSequence(input_size, hidden_size, output_size)
output_seq = rnn_seq(input_data)
print("Sequence output shape:", output_seq.shape)  # Should be (batch_size, output_size)
```