ML Playground / Decision Tree

Decision Tree

A tree-structured model that splits data based on feature conditions for classification and regression.

What is a Decision Tree?

A Decision Tree is a supervised learning algorithm that splits data into branches based on feature conditions, forming a tree structure. It works for both classification (predict a category) and regression (predict a number). The result is an interpretable, flowchart-like model.

Decision Trees are one of the most interpretable ML models. You can literally read the decision path like a flowchart: "If income > 15 LPA AND age > 35, then predict Yes."

Key Concepts

Impurity Measures

At each node, the algorithm picks the feature (and split point) that produces the "purest" subgroups. Purity is measured by:

Gini Impurity

Measures probability of misclassifying a random sample. Gini = 1 - sum(p_i^2). Used for classification. Lower = purer.

Entropy (Information Gain)

Measures disorder/uncertainty. Entropy = -sum(p_i * log2(p_i)). The split that most reduces entropy (i.e., maximizes information gain) is chosen. Used for classification. Lower = purer.
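Both measures can be computed directly from the class proportions in a node. A minimal sketch in NumPy (the function names `gini` and `entropy` are just illustrative helpers):

```python
import numpy as np

def gini(labels):
    """Gini impurity: 1 - sum(p_i^2) over class proportions."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def entropy(labels):
    """Entropy: -sum(p_i * log2(p_i)) over class proportions."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

print(gini([0, 0, 1, 1]))     # 0.5 -- maximally mixed two-class node
print(entropy([0, 0, 1, 1]))  # 1.0
print(gini([0, 0, 0, 0]))     # 0.0 -- pure node
```

Both reach their minimum (0) on a pure node and their maximum on an evenly mixed one, which is why "lower = purer" holds for each.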

How It Works

Algorithm Steps
  1. Select the best feature for splitting (using Gini or Entropy)
  2. Split the data into subgroups based on that feature's condition
  3. Repeat recursively for each subgroup
  4. Stop when: all data in a node belongs to one class (pure), depth limit reached, or further splits don't improve accuracy
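The four steps above can be sketched as a small recursive splitter. This is a toy illustration of the greedy algorithm, not how scikit-learn implements it; `best_split`, `build_tree`, and `predict` are hypothetical helper names:

```python
import numpy as np

def best_split(X, y):
    """Step 1: try every feature/threshold, return the split with lowest weighted Gini."""
    def gini(labels):
        _, counts = np.unique(labels, return_counts=True)
        p = counts / counts.sum()
        return 1.0 - np.sum(p ** 2)

    best = None  # (feature, threshold, weighted_gini)
    for f in range(X.shape[1]):
        for t in np.unique(X[:, f]):
            left, right = y[X[:, f] < t], y[X[:, f] >= t]
            if len(left) == 0 or len(right) == 0:
                continue  # degenerate split, skip
            score = (len(left) * gini(left) + len(right) * gini(right)) / len(y)
            if best is None or score < best[2]:
                best = (f, t, score)
    return best

def build_tree(X, y, depth=0, max_depth=3):
    """Steps 2-4: partition on the best split, recurse, stop on purity or depth limit."""
    if len(np.unique(y)) == 1 or depth == max_depth:
        return {"leaf": int(np.bincount(y).argmax())}  # majority class of the node
    split = best_split(X, y)
    if split is None:
        return {"leaf": int(np.bincount(y).argmax())}
    f, t, _ = split
    mask = X[:, f] < t
    return {"feature": f, "threshold": t,
            "left": build_tree(X[mask], y[mask], depth + 1, max_depth),
            "right": build_tree(X[~mask], y[~mask], depth + 1, max_depth)}

def predict(node, x):
    """Walk the tree from the root until a leaf is reached."""
    while "leaf" not in node:
        node = node["left"] if x[node["feature"]] < node["threshold"] else node["right"]
    return node["leaf"]

# Toy data: columns are [Age, Income_LPA], label is Buys_House
X = np.array([[22, 8], [25, 10], [35, 12], [40, 25], [50, 20], [55, 35]])
y = np.array([0, 0, 1, 1, 0, 1])
tree = build_tree(X, y)
print(predict(tree, np.array([30, 9])))  # → 0
```

Note that each split is chosen greedily at the current node (step 1), which is why the "Cons" below warn that the result may not be globally optimal.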

Example: House Purchase Prediction

[Root: Income]
 ├── Income < 15  → No
 └── Income >= 15 → [Check Age]
      ├── Age < 35  → No
      └── Age >= 35 → Yes

Code Implementation

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt

# Dataset
data = {
    "Age": [22, 25, 30, 35, 40, 50, 55],
    "Income_LPA": [8, 10, 15, 12, 25, 20, 35],
    "Buys_House": [0, 0, 0, 1, 1, 0, 1]
}
df = pd.DataFrame(data)

# Features and target
X = df[["Age", "Income_LPA"]]
y = df["Buys_House"]

# Train-test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train Decision Tree
dt_model = DecisionTreeClassifier(criterion="gini", max_depth=5, random_state=2)
dt_model.fit(X_train, y_train)

# Predict and evaluate
y_pred = dt_model.predict(X_test)
print(f"Accuracy: {accuracy_score(y_test, y_pred):.2f}")
print("Classification Report:\n", classification_report(y_test, y_pred, zero_division=0))

# Visualize the tree
plt.figure(figsize=(20, 16))
plot_tree(dt_model, feature_names=["Age", "Income_LPA"], class_names=["No", "Yes"], filled=True)
plt.show()
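Beyond the plot, the learned rules can be dumped as text and the fitted model used to score new data. A short, self-contained sketch on the same toy dataset (the applicant values below are made up for illustration):

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text

# Same toy dataset as above
df = pd.DataFrame({
    "Age": [22, 25, 30, 35, 40, 50, 55],
    "Income_LPA": [8, 10, 15, 12, 25, 20, 35],
    "Buys_House": [0, 0, 0, 1, 1, 0, 1],
})
model = DecisionTreeClassifier(criterion="gini", max_depth=3, random_state=42)
model.fit(df[["Age", "Income_LPA"]], df["Buys_House"])

# Print the learned rules as indented if/else text -- handy when a plot isn't available
print(export_text(model, feature_names=["Age", "Income_LPA"]))

# Score two hypothetical new applicants (rows are [Age, Income_LPA])
new_applicants = pd.DataFrame({"Age": [28, 45], "Income_LPA": [9, 30]})
print(model.predict(new_applicants))  # 0 = won't buy, 1 = will buy
```

The `export_text` output is the flowchart from the example above in text form: each indentation level is one split down the tree.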

Pros and Cons

Advantages:
- Easy to understand and interpret
- Handles numeric and categorical data
- Works with non-linear relationships
- No feature scaling required

Disadvantages:
- Overfits easily if tree is too deep
- Sensitive to noisy data
- Small changes in data can change the entire tree
- Greedy splits may not find global optimum

Always set max_depth or min_samples_split to prevent overfitting. An unconstrained tree will memorize the training data.

Key Parameters
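The listing above uses scikit-learn's DecisionTreeClassifier; its most commonly tuned parameters are sketched below. The values shown are illustrative, not recommendations, and the comments summarize each parameter's effect:

```python
from sklearn.tree import DecisionTreeClassifier

model = DecisionTreeClassifier(
    criterion="gini",      # impurity measure: "gini" or "entropy"
    max_depth=5,           # cap on tree depth; default None grows the tree fully
    min_samples_split=2,   # minimum samples a node needs before it can be split
    min_samples_leaf=1,    # minimum samples that must end up in each leaf
    max_features=None,     # features considered per split; None = all features
    random_state=42,       # reproducible tie-breaking between equally good splits
)
```

Raising `max_depth`, `min_samples_split`, or `min_samples_leaf` constraints is the standard way to act on the overfitting warning above: each one stops the tree from carving out leaves for individual noisy points.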

Tags: Classification, Regression, Supervised, Tree-based, Interpretable