LSTM (Long Short-Term Memory)
An improved RNN architecture that solves the vanishing gradient problem using gates to control information flow.
What is LSTM?
LSTM is a type of Recurrent Neural Network designed to remember information for long sequences. Unlike a regular RNN, LSTMs use a gating mechanism with a dedicated cell state to decide what to remember, what to forget, and what to output at each step.
LSTMs solve the vanishing gradient problem that makes standard RNNs forget earlier inputs in long sequences. This makes them effective for text generation, speech recognition, and time-series forecasting.
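To see why gradients vanish in a plain RNN, here is a toy scalar sketch (the recurrent weight 0.5 and the constant input 0.1 are arbitrary choices for illustration): backpropagating through tanh steps multiplies the gradient by a factor smaller than 1 at every step, so it shrinks geometrically with sequence length.

```python
import numpy as np

# Toy illustration of vanishing gradients in a plain RNN.
# Each backprop step multiplies the gradient by w * (1 - h^2),
# a factor below 1, so the product decays geometrically.
w = 0.5          # recurrent weight (scalar toy case)
h = 0.0
grad = 1.0
grads = []
for t in range(30):
    h = np.tanh(w * h + 0.1)      # forward step
    grad *= w * (1 - h ** 2)      # chain-rule factor for this step
    grads.append(grad)

print(f"gradient after 5 steps:  {grads[4]:.6f}")
print(f"gradient after 30 steps: {grads[29]:.2e}")
```

After 30 steps the surviving gradient is vanishingly small, which is exactly the signal a standard RNN would need to learn a 30-step dependency. The LSTM's additive cell-state update avoids this repeated multiplication.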
Key Concepts
Cell State vs Hidden State
- Cell State: A "conveyor belt" that carries important information through the sequence. Information can flow unchanged, preventing forgetting.
- Hidden State: The output of the current step, passed forward to the next step (same as regular RNN).
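You can inspect both states directly in Keras: with `return_state=True`, an LSTM layer returns its output plus the final hidden and cell states (the sizes below, 4 units over 3 time steps of 2 features, are arbitrary):

```python
import numpy as np
import tensorflow as tf

# return_state=True makes the layer also hand back its final states.
lstm = tf.keras.layers.LSTM(4, return_state=True)

x = np.random.rand(1, 3, 2).astype("float32")  # 1 sequence, 3 steps, 2 features
output, hidden_state, cell_state = lstm(x)

print(output.shape)                      # (1, 4)
print(np.allclose(output, hidden_state)) # True: the output IS the final hidden state
```

Note that the layer's output equals the final hidden state, while the cell state is a separate tensor carried internally from step to step.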
The Three Gates
Forget Gate
Decides what past information to throw away from the cell state.
Input Gate
Decides what new information to add to the cell state.
Output Gate
Decides what information from the cell state to output at the current step.
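The three gates can be written out directly. Below is a minimal NumPy sketch of a single LSTM cell step; the weight names (W_f, U_f, ...) and the tiny sizes are illustrative, not taken from the Keras implementation:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_cell_step(x, h_prev, c_prev, W, U, b):
    """One LSTM step. W, U, b each hold four parameter sets,
    in the order: forget, input, candidate, output."""
    W_f, W_i, W_c, W_o = W
    U_f, U_i, U_c, U_o = U
    b_f, b_i, b_c, b_o = b

    f = sigmoid(W_f @ x + U_f @ h_prev + b_f)        # forget gate: what to drop
    i = sigmoid(W_i @ x + U_i @ h_prev + b_i)        # input gate: what to add
    c_tilde = np.tanh(W_c @ x + U_c @ h_prev + b_c)  # candidate cell values
    c = f * c_prev + i * c_tilde                     # new cell state (additive update)
    o = sigmoid(W_o @ x + U_o @ h_prev + b_o)        # output gate: what to expose
    h = o * np.tanh(c)                               # new hidden state
    return h, c

# Tiny example: 2 input features, 3 hidden units, random weights
rng = np.random.default_rng(0)
W = [rng.standard_normal((3, 2)) for _ in range(4)]
U = [rng.standard_normal((3, 3)) for _ in range(4)]
b = [np.zeros(3) for _ in range(4)]
h, c = lstm_cell_step(rng.standard_normal(2), np.zeros(3), np.zeros(3), W, U, b)
print(h.shape, c.shape)  # (3,) (3,)
```

The key line is `c = f * c_prev + i * c_tilde`: the cell state is updated additively, which is what lets gradients flow through long sequences without vanishing.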
RNN vs LSTM
| Feature | RNN | LSTM |
| --- | --- | --- |
| Memory | Short-term only | Long-term (via cell state) |
| Gates | None | Input, Forget, Output |
| Long sequences | Poor (vanishing gradient) | Excellent |
| Training stability | Unstable for deep networks | Much more stable |
Code: Prepare Data (Same as RNN)
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense
# Example corpus
corpus = [
"hello how are you",
"hello how is your day",
"hello how are your friends",
"hello what are you doing"
]
# Tokenize the text
tokenizer = Tokenizer()
tokenizer.fit_on_texts(corpus)
total_words = len(tokenizer.word_index) + 1 # +1 because word indices start at 1 (0 is reserved for padding)
print("Total unique words:", total_words)
# Create input sequences
input_sequences = []
for line in corpus:
token_list = tokenizer.texts_to_sequences([line])[0]
for i in range(1, len(token_list)):
n_gram_sequence = token_list[:i+1]
input_sequences.append(n_gram_sequence)
# Pad sequences to same length
max_seq_len = max([len(x) for x in input_sequences])
input_sequences = np.array(pad_sequences(input_sequences, maxlen=max_seq_len, padding='pre'))
# Split inputs (X) and labels (y)
X = input_sequences[:, :-1]
y = input_sequences[:, -1]
print("Example X[0]:", X[0], "-> y[0]:", y[0])
Code: Build and Train LSTM
# Build LSTM Model
model = Sequential()
model.add(Embedding(input_dim=total_words, output_dim=10, input_length=max_seq_len-1))
model.add(LSTM(50, activation='tanh')) # LSTM layer (tanh is default)
model.add(Dense(total_words, activation='softmax')) # Predict next word
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# Train the model
history = model.fit(X, y, epochs=200, verbose=0)
print("Training complete!")
Code: Predict Next Word
def predict_next_word_lstm(model, tokenizer, text_seq, max_seq_len):
token_list = tokenizer.texts_to_sequences([text_seq])[0]
token_list = pad_sequences([token_list], maxlen=max_seq_len-1, padding='pre')
predicted = model.predict(token_list, verbose=0)
predicted_word_index = int(np.argmax(predicted, axis=-1)[0])
return tokenizer.index_word.get(predicted_word_index) # reverse lookup; None if no word matches
# Test
seed_text = "hello how is"
next_word = predict_next_word_lstm(model, tokenizer, seed_text, max_seq_len)
print(f"Input: '{seed_text}' -> Predicted next word: '{next_word}'")
Notice the code is nearly identical to the RNN version -- the only change is replacing SimpleRNN with LSTM. The LSTM uses tanh by default for its cell and output activations (the gates themselves use sigmoid), so leave the default in place rather than switching to relu as is sometimes done with SimpleRNN.
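LSTM layers can also be stacked for more capacity: every layer except the last needs `return_sequences=True` so the next layer receives one vector per time step instead of only the final output. A minimal sketch (vocab_size=20 and seq_len=4 are made-up sizes for illustration, not from the corpus above):

```python
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense

vocab_size, seq_len = 20, 4  # illustrative sizes

model = Sequential([
    Embedding(input_dim=vocab_size, output_dim=10),
    LSTM(50, return_sequences=True),  # pass the full sequence to the next LSTM
    LSTM(50),                         # last LSTM returns only the final step
    Dense(vocab_size, activation='softmax'),
])

# Forward pass on dummy token IDs: one probability row per input sequence.
dummy = np.random.randint(1, vocab_size, size=(2, seq_len))
probs = model.predict(dummy, verbose=0)
print(probs.shape)  # (2, 20)
```

Forgetting `return_sequences=True` on an intermediate layer is a common error: the second LSTM would then receive a 2-D tensor and raise a shape error.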
When to Use LSTM
| Good For | Not Ideal For |
| --- | --- |
| Long text sequences | Image data (use CNNs) |
| Time-series forecasting | Short, simple sequences (RNN may suffice) |
| Speech recognition | When training speed is critical (use Transformers) |
| Language modeling | Very long contexts (Transformers handle better) |