
LSTM (Long Short-Term Memory)

An improved RNN architecture that mitigates the vanishing gradient problem by using gates to control information flow.

What is LSTM?

LSTM is a type of Recurrent Neural Network designed to remember information for long sequences. Unlike a regular RNN, LSTMs use a gating mechanism with a dedicated cell state to decide what to remember, what to forget, and what to output at each step.

LSTMs mitigate the vanishing gradient problem that makes standard RNNs forget earlier inputs in long sequences. This makes them effective for text generation, speech recognition, and time-series forecasting.

Key Concepts

Cell State vs Hidden State

The cell state is the LSTM's long-term memory: it runs along the whole sequence and is changed only by the gates (a forget step and an additive update), so information can survive across many time steps. The hidden state is the short-term, per-step output: it is computed from the cell state through the output gate, and it is what the next layer or the next time step actually sees.

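In Keras, both states can be inspected directly. A minimal sketch (random input; batch size, sequence length, and unit counts chosen arbitrarily for illustration) using `return_state=True`:

```python
# Sketch: a Keras LSTM exposes both states when return_state=True.
import numpy as np
import tensorflow as tf

# Batch of 2 sequences, 5 time steps, 3 features each (arbitrary shapes)
x = np.random.rand(2, 5, 3).astype("float32")

lstm = tf.keras.layers.LSTM(4, return_state=True)
output, h, c = lstm(x)

print(output.shape)  # (2, 4) -- last output, identical to h here
print(h.shape)       # (2, 4) -- hidden state (short-term)
print(c.shape)       # (2, 4) -- cell state (long-term)
```

With `return_state=False` (the default, as in the model below) the layer returns only the final hidden state; the cell state stays internal.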
The Three Gates

Forget Gate

Decides what past information to throw away from the cell state.

Input Gate

Decides what new information to add to the cell state.

Output Gate

Decides what information from the cell state to output at the current step.
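The three gates can be written down concretely. The following NumPy sketch of a single LSTM step uses random weights purely for illustration (the names `W`, `b`, and `lstm_step` are ours, not Keras internals):

```python
# Sketch of one LSTM step in plain NumPy, showing how the three gates
# update the cell state c and the hidden state h. Weights are random.
import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

rng = np.random.default_rng(0)
n_in, n_hid = 3, 4

# One weight matrix and bias per gate plus the candidate, acting on [h, x]
W = {k: rng.normal(size=(n_hid, n_hid + n_in)) for k in ("f", "i", "o", "g")}
b = {k: np.zeros(n_hid) for k in ("f", "i", "o", "g")}

def lstm_step(x, h, c):
    z = np.concatenate([h, x])
    f = sigmoid(W["f"] @ z + b["f"])  # forget gate: what to drop from c
    i = sigmoid(W["i"] @ z + b["i"])  # input gate: what new info to admit
    g = np.tanh(W["g"] @ z + b["g"])  # candidate values to add
    o = sigmoid(W["o"] @ z + b["o"])  # output gate: what part of c to expose
    c = f * c + i * g                 # update long-term cell state (additive)
    h = o * np.tanh(c)                # new short-term hidden state
    return h, c

h = np.zeros(n_hid)
c = np.zeros(n_hid)
h, c = lstm_step(rng.normal(size=n_in), h, c)
print(h.shape, c.shape)  # (4,) (4,)
```

Note the cell-state update `c = f * c + i * g` is additive: this is what lets gradients flow back through many steps without repeatedly shrinking.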

RNN vs LSTM

Feature            | RNN                        | LSTM
Memory             | Short-term only            | Long-term (via cell state)
Gates              | None                       | Input, Forget, Output
Long sequences     | Poor (vanishing gradient)  | Excellent
Training stability | Unstable for deep networks | Much more stable
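The "vanishing gradient" row can be made concrete with a toy calculation. Backpropagating through a plain RNN multiplies the gradient by the tanh derivative (at most 1, typically well below it) once per time step, so the signal reaching early steps shrinks geometrically; the LSTM's additive cell-state update avoids this repeated multiplication. A sketch, assuming a typical pre-activation value of 1.0:

```python
# Toy illustration of vanishing gradients in a plain RNN: the backward
# signal is multiplied by tanh'(z) < 1 at every time step.
import numpy as np

z = 1.0                       # a typical pre-activation value (assumption)
d_tanh = 1 - np.tanh(z) ** 2  # tanh derivative at z, roughly 0.42

grad = 1.0
for step in range(50):        # 50 time steps of backpropagation
    grad *= d_tanh

print(f"{grad:.2e}")          # vanishingly small: early steps get almost no signal
```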

Code: Prepare Data (Same as RNN)

import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense

# Example corpus
corpus = [
    "hello how are you",
    "hello how is your day",
    "hello how are your friends",
    "hello what are you doing"
]

# Tokenize the text
tokenizer = Tokenizer()
tokenizer.fit_on_texts(corpus)
total_words = len(tokenizer.word_index) + 1  # +1 because indexing starts from 1
print("Total unique words:", total_words)

# Create input sequences
input_sequences = []
for line in corpus:
    token_list = tokenizer.texts_to_sequences([line])[0]
    for i in range(1, len(token_list)):
        n_gram_sequence = token_list[:i+1]
        input_sequences.append(n_gram_sequence)

# Pad sequences to same length
max_seq_len = max([len(x) for x in input_sequences])
input_sequences = np.array(pad_sequences(input_sequences, maxlen=max_seq_len, padding='pre'))

# Split inputs (X) and labels (y)
X = input_sequences[:, :-1]
y = input_sequences[:, -1]
print("Example X[0]:", X[0], "-> y[0]:", y[0])
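The n-gram windowing step is worth seeing in isolation. This standalone sketch (plain Python; the token IDs are invented for illustration) shows what `input_sequences` contains before padding:

```python
# Standalone illustration of the n-gram windowing used above: each sentence
# becomes every prefix of length >= 2, i.e. (context words, next word) pairs.
tokens = [1, 2, 3, 4]  # e.g. "hello how are you" after tokenization (IDs invented)
input_sequences = [tokens[:i+1] for i in range(1, len(tokens))]
print(input_sequences)  # [[1, 2], [1, 2, 3], [1, 2, 3, 4]]
```

After padding, the last column of each row becomes the label `y` and the rest becomes the input `X`.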

Code: Build and Train LSTM

# Build LSTM Model
model = Sequential()
model.add(Embedding(input_dim=total_words, output_dim=10, input_length=max_seq_len-1))
model.add(LSTM(50, activation='tanh'))              # LSTM layer (tanh is the default)
model.add(Dense(total_words, activation='softmax')) # Predict next word

model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# Train the model
history = model.fit(X, y, epochs=200, verbose=0)
print("Training complete!")

Code: Predict Next Word

def predict_next_word_lstm(model, tokenizer, text_seq, max_seq_len):
    token_list = tokenizer.texts_to_sequences([text_seq])[0]
    token_list = pad_sequences([token_list], maxlen=max_seq_len-1, padding='pre')
    predicted = model.predict(token_list, verbose=0)
    predicted_word_index = np.argmax(predicted)
    for word, index in tokenizer.word_index.items():
        if index == predicted_word_index:
            return word

# Test
seed_text = "hello how is"
next_word = predict_next_word_lstm(model, tokenizer, seed_text, max_seq_len)
print(f"Input: '{seed_text}' -> Predicted next word: '{next_word}'")

Notice the code is nearly identical to the RNN version -- the only change is replacing SimpleRNN with LSTM. The LSTM uses tanh activation by default (not relu), which is important for its gating mechanism.

When to Use LSTM

Good For                | Not Ideal For
Long text sequences     | Image data (use CNNs)
Time-series forecasting | Short, simple sequences (RNN may suffice)
Speech recognition      | When training speed is critical (use Transformers)
Language modeling       | Very long contexts (Transformers handle better)

Tags: LSTM, Gates, Cell State, Vanishing Gradient, Keras