Plotting for ML & Data Science — Complete Reference

Data without visualization is hard to reason about. This guide covers everything from the foundational mechanics of Matplotlib to the statistical power of Seaborn and the interactivity of Plotly. Together, these three libraries cover most of the plotting you will need in ML and DS work.

PART I

Introduction

What is a plot?

A plot is a graphical representation of data. It translates raw numbers into visual geometry (dots, lines, bars) projected onto coordinate axes. In Machine Learning, plotting is not optional: it is one of the fastest ways to detect anomalies, understand feature distributions, visualize decision boundaries, and track a model's loss during training.

The Big Three Libraries

In the Python data science ecosystem, plotting is dominated by three main libraries, each serving a distinct purpose:

1. Matplotlib
The absolute foundation. It is low-level, extremely customizable, and powerful. You must know Matplotlib to succeed in DS, because many other tools, including Seaborn and pandas' built-in `.plot()`, are built on top of it.
2. Seaborn
Built on top of Matplotlib, Seaborn provides a high-level interface for drawing attractive and informative statistical graphics (like heatmaps and pairplots). It requires significantly less code than Matplotlib.
3. Plotly
Unlike Matplotlib, Plotly renders in the browser using JavaScript. It allows for highly interactive plots (zooming, hovering, panning) which is essential for presenting data to stakeholders or exploring dense datasets.
PART II

The Core Plot Types

1. Scatter Plot

A scatter plot uses dots to represent values for two different numeric variables. The position of each dot on the horizontal and vertical axes indicates the values of an individual data point.

ML Use Case: Identifying clusters, observing correlation between two features, and plotting Principal Component Analysis (PCA) reductions.
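A minimal Matplotlib sketch of a scatter plot, using two synthetic features (the data and the strength of their correlation are made up for illustration). The Pearson correlation printed at the end puts a number on the trend the plot shows:

```python
import matplotlib.pyplot as plt
import numpy as np

# Synthetic data: feature_b is a noisy linear function of feature_a
rng = np.random.default_rng(0)
feature_a = rng.normal(size=200)
feature_b = 2 * feature_a + rng.normal(scale=0.5, size=200)

fig, ax = plt.subplots(figsize=(6, 4))
# alpha < 1 keeps overlapping points visible in dense regions
ax.scatter(feature_a, feature_b, alpha=0.6)
ax.set_xlabel('feature_a')
ax.set_ylabel('feature_b')
ax.set_title('Scatter Plot: Two Correlated Features')
plt.show()

# Pearson correlation between the two features (close to +1 for this data)
r = np.corrcoef(feature_a, feature_b)[0, 1]
print(f'Pearson r = {r:.2f}')
```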

2. Line Graph

A line graph connects individual data points with line segments. It displays quantitative values over a continuous interval or time period.

ML Use Case: Tracking Loss and Accuracy over training epochs. Time series forecasting.
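A minimal sketch of a line graph for a time series, assuming a synthetic daily series (trend, weekly cycle, and noise are all made up). The 7-day moving average is just one illustrative way to overlay a smoothed trend on the raw line:

```python
import matplotlib.pyplot as plt
import numpy as np

# Synthetic daily series: upward trend + weekly seasonality + noise
rng = np.random.default_rng(1)
days = np.arange(120)
series = 0.05 * days + np.sin(2 * np.pi * days / 7) + rng.normal(scale=0.3, size=days.size)

# 7-day moving average; mode='valid' drops the first 6 days, which lack a full window
window = 7
moving_avg = np.convolve(series, np.ones(window) / window, mode='valid')

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(days, series, alpha=0.5, label='Daily value')
ax.plot(days[window - 1:], moving_avg, linewidth=2, label=f'{window}-day moving average')
ax.set_xlabel('Day')
ax.set_ylabel('Value')
ax.set_title('Line Graph: Time Series with a Moving Average')
ax.legend()
plt.show()
```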

3. Bar Graph

A bar graph presents categorical data as rectangular bars whose heights (or lengths) are proportional to the values they represent.

ML Use Case: Visualizing Feature Importance (which features matter most to your Random Forest), or showing class imbalance in your dataset.
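A sketch of the class-imbalance use case, assuming hypothetical binary labels drawn at roughly 95% / 5%. `np.unique(..., return_counts=True)` does the counting, and `ax.bar_label` (Matplotlib 3.4+) writes each class's share on top of its bar:

```python
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical labels for an imbalanced binary dataset (about 95% / 5%)
rng = np.random.default_rng(2)
y = rng.choice(['negative', 'positive'], size=1000, p=[0.95, 0.05])

# Count how many samples belong to each class (classes come back sorted)
classes, counts = np.unique(y, return_counts=True)

fig, ax = plt.subplots(figsize=(5, 4))
bars = ax.bar(classes, counts, color=['#1B4F72', '#C0392B'])
# Label each bar with its share of the dataset
ax.bar_label(bars, labels=[f'{c / counts.sum():.1%}' for c in counts])
ax.set_ylabel('Number of samples')
ax.set_title('Class Distribution')
plt.show()
```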

4. Histogram

A histogram is an approximate representation of the distribution of numerical data. It divides the entire range of values into a series of intervals (bins) and counts how many values fall into each bin.

ML Use Case: Understanding the distribution of a single feature (e.g. checking if a feature is normally distributed or heavily skewed).
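A sketch of checking skew with histograms, assuming a synthetic log-normal feature as a stand-in for something like income. The right-hand panel shows how a log transform can make such a feature roughly symmetric:

```python
import matplotlib.pyplot as plt
import numpy as np

# Synthetic right-skewed feature (log-normal), standing in for something like income
rng = np.random.default_rng(3)
income = rng.lognormal(mean=10, sigma=0.8, size=5000)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
# Left: raw values, with a long right tail
axes[0].hist(income, bins=50)
axes[0].set_title('Raw feature (right-skewed)')
# Right: the same values after a log transform look roughly bell-shaped
axes[1].hist(np.log(income), bins=50)
axes[1].set_title('log(feature) (roughly symmetric)')
for ax in axes:
    ax.set_ylabel('Count')
plt.show()

# A mean well above the median is a quick numeric hint of right skew
print(f'mean = {income.mean():,.0f}, median = {np.median(income):,.0f}')
```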
PART III

Matplotlib Mechanics

How to Plot Data

Matplotlib has two interfaces: the Pyplot (state-based) interface and the Object-Oriented (OO) interface. In ML work, prefer the Object-Oriented interface (`fig, ax = plt.subplots()`), since it gives you explicit control over each set of axes, which matters as soon as a figure has more than one subplot.

python
import matplotlib.pyplot as plt
import numpy as np

# 1. Generate some data
x = np.linspace(0, 10, 100)
y = np.sin(x)

# 2. Create Figure and Axes (Object-Oriented API)
fig, ax = plt.subplots(figsize=(8, 4))

# 3. Plot the data
ax.plot(x, y)

# 4. Display
plt.show()

How to Label a Plot

Never leave a plot unlabeled. You must set the title, x-axis label, y-axis label, and include a legend if multiple elements are present.

python
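import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y = np.sin(x)
fig, ax = plt.subplots(figsize=(8, 4))

ax.plot(x, y, label='Sine Wave')

# Labeling the axes and title
ax.set_title('Signal Strength over Time')
ax.set_xlabel('Time (seconds)')
ax.set_ylabel('Amplitude')

# Show the legend (uses the label parameter from .plot)
ax.legend(loc='upper right')
plt.show()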

How to Scale an Axis

Often in Data Science, data spans multiple orders of magnitude (like income or exponential growth). You will need to scale your axis or set absolute limits.

python
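import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0.1, 10, 100)  # log scales need strictly positive values
fig, ax = plt.subplots()
ax.plot(x, np.sin(x))

# Setting explicit limits (on a log-scaled axis, limits must be > 0)
ax.set_xlim(0.1, 10)
ax.set_ylim(-1.5, 1.5)

# Changing the scale to logarithmic (useful when data spans orders of magnitude)
ax.set_xscale('log')
ax.set_yscale('linear')  # 'linear' is the default; shown for completeness
plt.show()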
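To see why a log scale helps, here is a small self-contained sketch with synthetic exponential growth (a value that doubles every step). On a linear y-axis the early points are flattened against zero; on a log y-axis the same data becomes a straight line:

```python
import matplotlib.pyplot as plt
import numpy as np

# Synthetic exponential growth: the value doubles every step
steps = np.arange(0, 21)
values = 2.0 ** steps

fig, (ax_lin, ax_log) = plt.subplots(1, 2, figsize=(10, 4))

# Linear y-axis: early values are squashed flat against zero
ax_lin.plot(steps, values)
ax_lin.set_title('Linear y-axis')

# Log y-axis: exponential growth becomes a straight line
ax_log.plot(steps, values)
ax_log.set_yscale('log')
ax_log.set_title('Log y-axis')

for ax in (ax_lin, ax_log):
    ax.set_xlabel('Step')
    ax.set_ylabel('Value')
plt.show()
```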

How to Plot Multiple Sets of Data

To plot multiple sets of data on the same graph, simply call the plot function multiple times on the same `ax` object before calling `plt.show()`.

python
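import matplotlib.pyplot as plt
import numpy as np

# Illustrative synthetic loss values (replace with your own training history)
epochs = np.arange(1, 21)
train_loss = 1.0 / epochs
val_loss = 1.0 / epochs + 0.02 * epochs

fig, ax = plt.subplots()

# Plotting multiple lines on the same axes
ax.plot(epochs, train_loss, color='blue', label='Training Loss')
ax.plot(epochs, val_loss, color='red', linestyle='--', label='Validation Loss')

ax.set_title('Model Training History')
ax.legend()
plt.show()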
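When the number of series is not fixed (for example, when comparing several training runs), the same idea works in a loop. The learning rates and loss curves below are hypothetical, synthetic values chosen only to produce three distinct lines:

```python
import matplotlib.pyplot as plt
import numpy as np

epochs = np.arange(1, 31)
# Hypothetical validation-loss curves for three learning rates (synthetic, for illustration)
runs = {
    'lr=0.1': 1.0 / np.sqrt(epochs) + 0.3,
    'lr=0.01': 1.0 / np.sqrt(epochs),
    'lr=0.001': 2.0 / np.sqrt(epochs),
}

fig, ax = plt.subplots(figsize=(8, 4))
# One ax.plot call per series; Matplotlib cycles through colors automatically
for name, val_loss in runs.items():
    ax.plot(epochs, val_loss, label=name)

ax.set_xlabel('Epoch')
ax.set_ylabel('Validation Loss')
ax.set_title('Comparing Runs on One Axes')
ax.legend()
plt.show()
```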
PART IV

Seaborn (Statistical Data Viz)

Seaborn shines when you need to visualize complex statistical relationships in pandas DataFrames with minimal code. It labels axes from DataFrame column names automatically and ships with attractive color palettes.

ML Highlight: The Correlation Heatmap
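Before training a model, Data Scientists often use Seaborn's `heatmap` to check feature correlation. If two features are highly correlated (multicollinearity), consider dropping one of them, especially for linear models, whose weights become unstable when inputs are nearly redundant.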
python
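import matplotlib.pyplot as plt
import seaborn as sns

# Assume 'df' is a Pandas DataFrame with housing data
# numeric_only=True skips text columns (required in pandas >= 2.0 if any exist)
corr_matrix = df.corr(numeric_only=True)

# Plotting a heatmap of correlations
plt.figure(figsize=(10, 8))
# vmin/vmax pin the color scale to the full [-1, 1] correlation range
sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', vmin=-1, vmax=1)
plt.title('Feature Correlation Heatmap')
plt.show()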
Figure: Feature correlation heatmap (output of `sns.heatmap()`)
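The heatmap is for eyeballing; to act on the "drop one of two highly correlated features" idea, a small helper can list the offending pairs. This is a sketch under stated assumptions: the 0.8 cutoff is a common rule of thumb rather than a fixed standard, the helper name `highly_correlated_pairs` is hypothetical, and the housing-style features are synthetic. With a DataFrame, you would pass the numeric columns' `to_numpy()` values and their names:

```python
import numpy as np

def highly_correlated_pairs(X, names, threshold=0.8):
    """Return (name_i, name_j, r) for feature pairs whose |Pearson r| exceeds threshold."""
    corr = np.corrcoef(X, rowvar=False)  # columns are features
    pairs = []
    n = corr.shape[0]
    for i in range(n):
        for j in range(i + 1, n):  # upper triangle only: skip self-pairs and duplicates
            if abs(corr[i, j]) > threshold:
                pairs.append((names[i], names[j], round(float(corr[i, j]), 3)))
    return pairs

# Synthetic housing-style features: 'rooms' is nearly a linear function of 'sqft'
rng = np.random.default_rng(4)
sqft = rng.normal(1500, 300, size=500)
rooms = sqft / 300 + rng.normal(0, 0.2, size=500)
age = rng.uniform(0, 50, size=500)
X = np.column_stack([sqft, rooms, age])

pairs = highly_correlated_pairs(X, ['sqft', 'rooms', 'age'])
print(pairs)
```

Using the absolute value matters: a correlation of -0.95 signals just as much redundancy as +0.95.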
PART V

Plotly (Interactive Visualizations)

When presenting models to non-technical stakeholders, static images often fall short. Plotly allows the user to hover over data points to see exact values, zoom into dense clusters, and toggle datasets on and off via the legend.

python
import plotly.express as px

# Load Plotly's built-in copy of the Iris dataset
df = px.data.iris()

# Plotly Express allows building interactive plots in one line of code
fig = px.scatter(df, x="sepal_width", y="sepal_length", color="species",
                 title="Interactive Iris Dataset")
fig.show()


PART VI

Essential ML Workflows

Learning Curves

A Learning Curve is a line plot showing the model's loss or accuracy on both the training and validation datasets over time (epochs) or over varying amounts of training data. It is the primary tool for diagnosing bias (underfitting) and variance (overfitting).

Overfitting Indicator
If your training loss continues to decrease, but your validation loss begins to increase, your model is memorizing the training data. You should apply early stopping, dropout, or regularization.
python
import matplotlib.pyplot as plt

# Assumes epochs, train_loss and val_loss are sequences recorded during training
# Plotting a Learning Curve with Matplotlib
fig, ax = plt.subplots(figsize=(8, 5))

ax.plot(epochs, train_loss, label='Training Loss', color='blue')
ax.plot(epochs, val_loss, label='Validation Loss', color='red', linestyle='--')

ax.set_title('Model Learning Curve')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
plt.show()
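A common form of early stopping keeps the model from the epoch with the lowest validation loss. The sketch below marks that epoch on the learning curve; the loss values are synthetic, chosen so that validation loss bottoms out at epoch 8 while training loss keeps falling:

```python
import matplotlib.pyplot as plt
import numpy as np

# Synthetic history: training loss keeps falling, validation loss turns up after epoch 8
epochs = np.arange(1, 21)
train_loss = 0.6 * 0.9 ** epochs
val_loss = 0.45 + 0.002 * (epochs - 8) ** 2

# Epoch an early-stopping rule based on the best validation loss would keep
best_idx = int(np.argmin(val_loss))
best_epoch = int(epochs[best_idx])

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(epochs, train_loss, label='Training Loss', color='blue')
ax.plot(epochs, val_loss, label='Validation Loss', color='red', linestyle='--')
# Dotted vertical line marking the lowest validation loss
ax.axvline(best_epoch, color='gray', linestyle=':', label=f'Lowest val loss (epoch {best_epoch})')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
plt.show()
```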

The Confusion Matrix

For classification tasks, accuracy alone is rarely enough, especially when classes are imbalanced. A Confusion Matrix visualizes True Positives, True Negatives, False Positives, and False Negatives. In ML, we usually use Seaborn's heatmap to make this readable.

python
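import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# Assumes y_true and y_pred hold the true and predicted class labels
# Calculate confusion matrix (rows = true labels, columns = predicted labels)
cm = confusion_matrix(y_true, y_pred)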

# Plot using Seaborn
plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title('Confusion Matrix')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.show()
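To see exactly what the four cells contain, a NumPy-only sketch can count them directly and draw them with `ax.imshow`. The layout `[[TN, FP], [FN, TP]]` (rows = true label, columns = predicted label, negative class first) matches what `sklearn.metrics.confusion_matrix` returns for 0/1 labels; the helper name and the labels are hypothetical:

```python
import matplotlib.pyplot as plt
import numpy as np

def binary_confusion_matrix(y_true, y_pred):
    """Return [[TN, FP], [FN, TP]] for 0/1 labels (rows: true, columns: predicted)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    tn = np.sum((y_true == 0) & (y_pred == 0))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    tp = np.sum((y_true == 1) & (y_pred == 1))
    return np.array([[tn, fp], [fn, tp]])

# Hypothetical labels and predictions
y_true = [0, 0, 0, 0, 1, 1, 1, 0, 1, 0]
y_pred = [0, 1, 0, 0, 1, 0, 1, 0, 1, 0]
cm = binary_confusion_matrix(y_true, y_pred)

fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(cm, cmap='Blues')
# Write each count in the middle of its cell
for i in range(2):
    for j in range(2):
        ax.text(j, i, str(cm[i, j]), ha='center', va='center')
ax.set_xticks([0, 1])
ax.set_xticklabels(['Pred 0', 'Pred 1'])
ax.set_yticks([0, 1])
ax.set_yticklabels(['True 0', 'True 1'])
ax.set_title('Confusion Matrix')
plt.show()
```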

ROC Curve & AUC

The Receiver Operating Characteristic (ROC) curve plots the True Positive Rate against the False Positive Rate at various threshold settings. The Area Under the Curve (AUC) provides an aggregate measure of performance across all possible classification thresholds.

ML Rule of Thumb: An AUC of 0.5 means the model is no better than random guessing. An AUC of 1.0 means the model ranks every positive example above every negative one (perfect separation).
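The ROC curve can be built by sweeping the decision threshold from the highest score down to the lowest. The NumPy sketch below does this by hand (in practice, `sklearn.metrics.roc_curve` and `sklearn.metrics.roc_auc_score` do the same job) and integrates the curve with the trapezoidal rule. The helper names, labels, and scores are hypothetical, and tied scores are not handled specially:

```python
import matplotlib.pyplot as plt
import numpy as np

def roc_curve_manual(y_true, scores):
    """Return (fpr, tpr) as the threshold sweeps from high to low.

    Assumes binary 0/1 labels and no tied scores.
    """
    y_true = np.asarray(y_true)
    order = np.argsort(-np.asarray(scores))  # highest score first
    y_sorted = y_true[order]
    # Lowering the threshold past each sample adds it to the predicted positives
    tps = np.cumsum(y_sorted == 1)
    fps = np.cumsum(y_sorted == 0)
    tpr = np.concatenate([[0.0], tps / tps[-1]])
    fpr = np.concatenate([[0.0], fps / fps[-1]])
    return fpr, tpr

def auc_trapezoid(fpr, tpr):
    """Area under the ROC curve via the trapezoidal rule."""
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2))

# Hypothetical labels and model scores (probability of the positive class)
y_true = [0, 0, 1, 1, 0, 1, 0, 1]
scores = [0.1, 0.4, 0.35, 0.8, 0.2, 0.7, 0.6, 0.9]

fpr, tpr = roc_curve_manual(y_true, scores)
auc = auc_trapezoid(fpr, tpr)

fig, ax = plt.subplots(figsize=(5, 5))
ax.plot(fpr, tpr, label=f'Model (AUC = {auc:.3f})')
ax.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Random guessing (AUC = 0.5)')
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('ROC Curve')
ax.legend()
plt.show()
```

For these scores, 14 of the 16 (positive, negative) pairs are ranked correctly, so the AUC is 14/16 = 0.875, which illustrates the ranking interpretation of AUC.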

Feature Importance

Tree-based models (like Random Forest or XGBoost) provide "feature importances". A horizontal bar chart, sorted from most to least important, is the standard way to visualize which features drove the model's predictions.

python
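import matplotlib.pyplot as plt
import numpy as np

# Assuming 'importances' is a NumPy array and 'features' is a list of strings
# (e.g. model.feature_importances_ from a fitted scikit-learn forest)
# Sort from most to least important so the bar order is meaningful
order = np.argsort(importances)[::-1]
sorted_features = np.asarray(features)[order]
sorted_importances = np.asarray(importances)[order]

fig, ax = plt.subplots(figsize=(10, 6))

# Horizontal bar chart
ax.barh(sorted_features, sorted_importances, color='#1B4F72')
ax.set_xlabel('Relative Importance')
ax.set_title('Random Forest Feature Importance')

# Invert y-axis so the first (most important) feature is drawn at the top
ax.invert_yaxis()
plt.show()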
PART VII

Exploratory Data Analysis (EDA)

Before training a model, an ML Engineer must understand the shape of the data. Seaborn is the go-to library for EDA plotting.

Box Plots for Outlier Detection

Box plots visualize the median, quartiles, and most importantly, the outliers of a distribution. Outliers often need to be capped, scaled, or removed before feeding data into linear models or neural networks.

python
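import matplotlib.pyplot as plt
import seaborn as sns

# Assumes 'df' has a categorical 'neighborhood' column and a numeric 'price' column
# Using Seaborn to plot a box plot of house prices by neighborhood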
plt.figure(figsize=(10, 6))
sns.boxplot(x='neighborhood', y='price', data=df)
plt.title('Distribution of House Prices (Detecting Outliers)')
plt.xticks(rotation=45)
plt.show()
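Box plots flag outliers with a simple rule: by default, Matplotlib and Seaborn let the whiskers reach at most 1.5 * IQR beyond the first and third quartiles (IQR = Q3 - Q1), and points past that are drawn individually. A NumPy sketch of the same rule, using hypothetical house prices and a hypothetical helper name, also shows capping as an alternative to dropping rows:

```python
import numpy as np

def iqr_outliers(values, k=1.5):
    """Flag values beyond k * IQR from the quartiles (the default box-plot whisker rule)."""
    values = np.asarray(values, dtype=float)
    q1, q3 = np.percentile(values, [25, 75])
    iqr = q3 - q1
    lower, upper = q1 - k * iqr, q3 + k * iqr
    mask = (values < lower) | (values > upper)
    return mask, (lower, upper)

# Hypothetical house prices (in thousands) with one obvious outlier
prices = np.array([210, 230, 245, 250, 260, 270, 275, 290, 300, 1200])
mask, (lower, upper) = iqr_outliers(prices)
print('Outliers:', prices[mask])
print(f'Fences: [{lower:.2f}, {upper:.2f}]')

# One option: cap (winsorize) at the fences instead of dropping rows
capped = np.clip(prices, lower, upper)
```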

Violin Plots

Violin plots combine a Box Plot with a Kernel Density Estimation (KDE) plot. They show the estimated probability density of the data at different values, revealing shape details (such as multiple peaks) that a simple box plot hides.

python
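import matplotlib.pyplot as plt
import seaborn as sns
iris_df = sns.load_dataset('iris')  # example dataset, downloaded on first use
# Violin plots handle multimodal distributions (data with multiple peaks) gracefully
sns.violinplot(data=iris_df, x='species', y='sepal_length', hue='species', palette='muted', legend=False)  # hue + legend=False: seaborn >= 0.13
plt.show()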

Pairplots

When you first receive a dataset, a Pairplot lets you plot pairwise relationships across all numeric columns of the DataFrame at once. It draws a scatter plot for each pair of features and each feature's univariate distribution on the diagonal (histograms by default, or KDE curves when `hue` is set).

python
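import matplotlib.pyplot as plt
import seaborn as sns

# Warning: the grid has one panel per pair of columns, so it grows quadratically;
# with >20 columns it becomes slow to render and hard to read.
# 'hue' color-codes the dots by class; corner=True draws only the lower triangle.
sns.pairplot(df, hue='target_class', corner=True)
plt.show()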
PART VIII

Dimensionality Reduction

In modern ML, especially NLP and Computer Vision, datasets often have hundreds or thousands of dimensions (embeddings). Humans can only see in 3 dimensions. We use techniques like PCA, t-SNE, or UMAP to compress these dimensions down to 2D, and then plot them using Scatter Plots.

Visualizing Latent Spaces
If you take a neural network that classifies dog breeds and run t-SNE on the activations of its second-to-last layer, you will often see clusters of similar dogs (e.g., all Terriers clustered together). Such a scatter plot is good evidence that the model has learned meaningful abstract relationships in the data, though it is not proof: t-SNE can distort distances, so treat cluster sizes and gaps with caution.
python
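import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Assumes X_high_dim is an (n_samples, 512) array of embeddings (e.g. penultimate-layer
# activations) and y_labels holds one integer class label per sample
# 1. Compress 512 dimensions down to 2 dimensions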
tsne = TSNE(n_components=2, random_state=42)
X_2d = tsne.fit_transform(X_high_dim)

# 2. Plot the 2D latent space
plt.figure(figsize=(10, 8))
# X_2d[:, 0] grabs the first dimension, X_2d[:, 1] grabs the second
scatter = plt.scatter(X_2d[:, 0], X_2d[:, 1], c=y_labels, cmap='viridis', alpha=0.6)

plt.colorbar(scatter, label='Class Label')
plt.title('t-SNE Projection of High-Dimensional Embeddings')
plt.show()
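PCA, the linear option among the techniques listed above, is simple enough to sketch with NumPy alone: center the data, take the SVD, and project onto the top two principal directions. The helper name `pca_2d` and the 50-dimensional data with three clusters are hypothetical; for real work, `sklearn.decomposition.PCA` is the usual choice:

```python
import matplotlib.pyplot as plt
import numpy as np

def pca_2d(X):
    """Project rows of X onto the first two principal components via SVD."""
    X_centered = X - X.mean(axis=0)
    # Rows of Vt are principal directions, sorted by decreasing singular value
    U, S, Vt = np.linalg.svd(X_centered, full_matrices=False)
    X_2d = X_centered @ Vt[:2].T
    explained = S**2 / np.sum(S**2)  # fraction of variance per component
    return X_2d, explained[:2]

# Synthetic data: three clusters in 50 dimensions
rng = np.random.default_rng(5)
centers = rng.normal(scale=5, size=(3, 50))
y_labels = np.repeat([0, 1, 2], 100)
X_high_dim = centers[y_labels] + rng.normal(size=(300, 50))

X_2d, explained = pca_2d(X_high_dim)

fig, ax = plt.subplots(figsize=(8, 6))
scatter = ax.scatter(X_2d[:, 0], X_2d[:, 1], c=y_labels, cmap='viridis', alpha=0.6)
fig.colorbar(scatter, ax=ax, label='Class Label')
ax.set_xlabel(f'PC1 ({explained[0]:.0%} of variance)')
ax.set_ylabel(f'PC2 ({explained[1]:.0%} of variance)')
ax.set_title('PCA Projection to 2D')
plt.show()
```

Unlike t-SNE, PCA is deterministic and linear, so distances along PC1 and PC2 are meaningful; the trade-off is that it can miss non-linear cluster structure.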
PART IX

The Decision Matrix

A Data Scientist's job is not just to draw plots, but to use them to make architectural decisions. Use these matrices to determine exactly which graph to reach for based on your current goal.

Data Science: Exploratory Data Analysis (EDA)

| Analysis Goal | Recommended Graph | Decision Value |
|---|---|---|
| Find anomalies and understand the spread of a single continuous feature | Box Plot / Violin Plot | Should I drop these outliers, cap them, or leave them for the model to learn? Are there multiple peaks in the distribution? |
| Check for multicollinearity (features that overlap) | Seaborn Correlation Heatmap | If two features have an absolute correlation above roughly 0.8, I should drop one of them before training a linear model to prevent unstable weights. |
| Understand class imbalance | Bar Chart / Count Plot | If 99% of data is Class A and 1% is Class B, I should use SMOTE, class weights, or stratified splitting. Accuracy is misleading here: always predicting Class A already scores 99%. |
| Quickly find relationships across the entire dataset | Pairplot (Scatter Matrix) | Are there natural linear relationships? Are there distinct clusters forming? Do I need a non-linear model? |

Machine Learning: Model Evaluation & Explainability

| Evaluation Goal | Recommended Graph | Decision Value |
|---|---|---|
| Diagnose model overfitting or underfitting | Learning Curves (Loss vs. Epochs) | If validation loss goes up while training loss goes down, I must stop training earlier or add Regularization/Dropout. |
| Evaluate a binary classifier's overall effectiveness | ROC Curve & AUC | Helps me decide the exact probability threshold to use to maximize business value (e.g. catching all fraud vs. minimizing false alarms). |
| Analyze where a model is making mistakes | Confusion Matrix | Tells me exactly which classes the model is confusing, e.g. is it confusing 'Cats' for 'Dogs', or 'Cats' for 'Cars'? |
| Explain a black-box model to stakeholders | Feature Importance / SHAP Plot | Tells the business why the model made its decision. Which features drove the prediction? Which features are useless and can be dropped? |