Decision Tree
A tree-structured model that splits data based on feature conditions for classification and regression.
What is a Decision Tree?
A Decision Tree is a supervised learning algorithm that splits data into branches based on feature conditions, forming a tree structure. It works for both classification (predict a category) and regression (predict a number). The result is an interpretable, flowchart-like model.
Decision Trees are among the most interpretable ML models. You can literally read the decision path like a flowchart: "If income >= 15 LPA AND age >= 35, then predict Yes."
Key Concepts
- Root Node — the starting point containing the entire dataset
- Decision Nodes — intermediate nodes where conditions are checked
- Leaf Nodes — final nodes that give the prediction
- Splitting — dividing data into subgroups based on the best feature
Impurity Measures
The algorithm picks the feature that produces the "purest" split. Purity is measured by:
Gini Impurity
Measures the probability of misclassifying a randomly chosen sample if it were labeled according to the node's class distribution. Gini = 1 - sum(p_i^2). Used for classification. Lower = purer.
Entropy (Information Gain)
Measures disorder/uncertainty. Entropy = -sum(p_i * log2(p_i)). Used for classification. Lower = purer. Information gain is the reduction in entropy achieved by a split; the algorithm picks the split with the highest gain.
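Both measures can be computed directly from class proportions. A minimal sketch in plain Python (no libraries beyond the standard `math` module; the function names are our own):

```python
import math

def gini(labels):
    """Gini impurity: 1 - sum(p_i^2) over the class proportions."""
    n = len(labels)
    return 1.0 - sum((labels.count(c) / n) ** 2 for c in set(labels))

def entropy(labels):
    """Entropy: -sum(p_i * log2(p_i)) over the class proportions."""
    n = len(labels)
    return -sum((labels.count(c) / n) * math.log2(labels.count(c) / n)
                for c in set(labels))

print(gini(["Yes", "Yes", "No", "No"]))     # 0.5 (maximally impure for 2 classes)
print(gini(["Yes", "Yes", "Yes", "Yes"]))   # 0.0 (pure node)
print(entropy(["Yes", "Yes", "No", "No"]))  # 1.0 (maximum entropy for 2 classes)
```

Note how both measures hit zero for a pure node and peak for a 50/50 mix, which is exactly why the algorithm prefers splits that drive them down.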
How It Works
Algorithm Steps
- Select the best feature for splitting (using Gini or Entropy)
- Split the data into subgroups based on that feature's condition
- Repeat recursively for each subgroup
- Stop when: all data in a node belongs to one class (pure), depth limit reached, or further splits don't improve accuracy
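The steps above can be sketched as a recursive greedy builder. This is an illustrative toy (one numeric feature, Gini criterion, binary threshold splits); `build_tree`, `best_split`, and `predict` are our own names, not from any library:

```python
def gini(labels):
    n = len(labels)
    return 1.0 - sum((labels.count(c) / n) ** 2 for c in set(labels))

def best_split(xs, ys):
    """Try every candidate threshold; return the one with lowest weighted Gini."""
    best_t, best_score = None, float("inf")
    for t in sorted(set(xs))[:-1]:
        left = [y for x, y in zip(xs, ys) if x <= t]
        right = [y for x, y in zip(xs, ys) if x > t]
        score = (len(left) * gini(left) + len(right) * gini(right)) / len(ys)
        if score < best_score:
            best_t, best_score = t, score
    return best_t

def build_tree(xs, ys, depth=0, max_depth=3):
    # Stop: node is pure or the depth limit is reached -> leaf with majority class
    if len(set(ys)) == 1 or depth == max_depth:
        return max(set(ys), key=ys.count)
    t = best_split(xs, ys)
    if t is None:  # no split possible (all feature values identical)
        return max(set(ys), key=ys.count)
    left = [(x, y) for x, y in zip(xs, ys) if x <= t]
    right = [(x, y) for x, y in zip(xs, ys) if x > t]
    return {"threshold": t,
            "left": build_tree(*zip(*left), depth + 1, max_depth),
            "right": build_tree(*zip(*right), depth + 1, max_depth)}

def predict(tree, x):
    if not isinstance(tree, dict):  # reached a leaf
        return tree
    branch = tree["left"] if x <= tree["threshold"] else tree["right"]
    return predict(branch, x)
```

Real implementations (CART in scikit-learn, for instance) handle multiple features, missing values, and pruning, but the greedy split-then-recurse shape is the same.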
Example: House Purchase Prediction
[Root: Income]
        |
   +-----------+-----------+
   |                       |
Income < 15           Income >= 15
   |                       |
  No                  [Check Age]
                           |
                    +------+------+
                    |             |
                Age < 35     Age >= 35
                    |             |
                   No            Yes
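Read as code, the flowchart above is just nested conditionals (a sketch; the thresholds are taken straight from the diagram and the function name is our own):

```python
def buys_house(income_lpa, age):
    """Decision path from the example tree: check income first, then age."""
    if income_lpa < 15:
        return "No"
    if age < 35:
        return "No"
    return "Yes"

print(buys_house(income_lpa=20, age=40))  # Yes
print(buys_house(income_lpa=10, age=40))  # No
```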
Code Implementation
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt
# Toy dataset (far too small for reliable evaluation; for illustration only)
data = {
"Age": [22, 25, 30, 35, 40, 50, 55],
"Income_LPA": [8, 10, 15, 12, 25, 20, 35],
"Buys_House": [0, 0, 0, 1, 1, 0, 1]
}
df = pd.DataFrame(data)
# Features and target
X = df[["Age", "Income_LPA"]]
y = df["Buys_House"]
# Train-test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Train Decision Tree
dt_model = DecisionTreeClassifier(criterion="gini", max_depth=5, random_state=2)
dt_model.fit(X_train, y_train)
# Predict and evaluate
y_pred = dt_model.predict(X_test)
print(f"Accuracy: {accuracy_score(y_test, y_pred):.2f}")
print("Classification Report:\n", classification_report(y_test, y_pred, zero_division=0))
# Visualize the tree
plt.figure(figsize=(20, 16))
plot_tree(dt_model, feature_names=["Age", "Income_LPA"], class_names=["No", "Yes"], filled=True)
plt.show()
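Beyond the plot, scikit-learn can also print the learned rules as indented text, which matches the "flowchart" reading of the tree. A self-contained sketch using the same toy data (here refit as `model` rather than reusing `dt_model` above):

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text

df = pd.DataFrame({
    "Age": [22, 25, 30, 35, 40, 50, 55],
    "Income_LPA": [8, 10, 15, 12, 25, 20, 35],
    "Buys_House": [0, 0, 0, 1, 1, 0, 1],
})
model = DecisionTreeClassifier(criterion="gini", max_depth=5, random_state=2)
model.fit(df[["Age", "Income_LPA"]], df["Buys_House"])

# Print the tree as human-readable if/else rules
print(export_text(model, feature_names=["Age", "Income_LPA"]))
```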
Pros and Cons
| Advantages | Disadvantages |
| --- | --- |
| Easy to understand and interpret | Overfits easily if tree is too deep |
| Handles numeric and categorical data | Sensitive to noisy data |
| Works with non-linear relationships | Small changes in data can change the entire tree |
| No feature scaling required | Greedy splits may not find global optimum |
Always set max_depth or min_samples_split to prevent overfitting. An unconstrained tree will memorize the training data.
Key Parameters
- criterion — "gini" or "entropy" (splitting quality measure)
- max_depth — Maximum tree depth (controls overfitting)
- min_samples_split — Minimum samples needed to split a node
- min_samples_leaf — Minimum samples required in each leaf
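These parameters are usually tuned together rather than one at a time; a small grid search is one common way. A sketch using a synthetic dataset (replace `make_classification` with your real data; the grid values are illustrative, not recommendations):

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data for the sketch
X, y = make_classification(n_samples=200, n_features=5, random_state=42)

param_grid = {
    "criterion": ["gini", "entropy"],
    "max_depth": [2, 3, 5, None],
    "min_samples_split": [2, 5, 10],
    "min_samples_leaf": [1, 2, 5],
}
search = GridSearchCV(DecisionTreeClassifier(random_state=42), param_grid, cv=5)
search.fit(X, y)

print("Best parameters:", search.best_params_)
print(f"Best CV accuracy: {search.best_score_:.2f}")
```

Cross-validated tuning like this is also the practical answer to the overfitting warning above: rather than guessing a depth, let held-out accuracy pick it.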
Tags: Classification, Regression, Supervised, Tree-based, Interpretable