
Matplotlib

Matplotlib is the foundational visualization library in Python, providing a comprehensive suite of tools for creating static, animated, and interactive visualizations. It offers fine-grained control over every aspect of a figure, from the smallest text element to complex subplot layouts. This tutorial covers everything from basic plotting to advanced customization and integration with other libraries.


Installation and Setup

Installing Matplotlib

pip install matplotlib

Or with conda:

conda install matplotlib

Import Conventions

import matplotlib.pyplot as plt
import numpy as np

# For Jupyter notebooks, enable inline plotting
%matplotlib inline

# For interactive mode in scripts
plt.ion()

Understanding the Architecture

Matplotlib has three main layers:

  • Backend: Handles rendering to different outputs (screen, files, etc.)
  • Artist: All visual elements (lines, text, patches, etc.)
  • Pyplot: A procedural interface that manages figures and axes
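The sketch below separates the layers by skipping pyplot entirely. It builds a `Figure` (Artist layer), attaches the Agg canvas (Backend layer), and renders to an in-memory RGBA buffer. It assumes only Matplotlib and NumPy. The figure size and data are arbitrary.

```python
import numpy as np
from matplotlib.artist import Artist
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure

# Artist layer: Figure, Axes and Line2D are all Artist subclasses
fig = Figure(figsize=(4, 3), dpi=100)
canvas = FigureCanvasAgg(fig)  # Backend layer: the Agg raster renderer
ax = fig.add_subplot()
(line,) = ax.plot(np.arange(5), np.arange(5) ** 2)

# Render without pyplot and read the pixels back
canvas.draw()
pixels = np.asarray(canvas.buffer_rgba())
print(pixels.shape)  # (300, 400, 4): height x width x RGBA at 100 dpi
print(all(isinstance(a, Artist) for a in (fig, ax, line)))  # True
```

This is roughly what pyplot does on your behalf. It picks a backend, creates the canvas, and keeps track of the "current" figure.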

Core Concepts

Figure and Axes

  • Figure: The top-level container that holds all plot elements
  • Axes: The actual plotting area with x and y axes
  • Axis: The x-axis and y-axis objects that handle limits, ticks, and labels
# Create figure and axes
fig, ax = plt.subplots()  # Recommended approach

# Alternative: Figure and axes separately
fig = plt.figure()
ax = fig.add_subplot(111)
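The next sketch shows the containment chain Figure → Axes → Axis. It uses the non-interactive Agg backend so it also runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
print(fig.axes == [ax])         # True: the Figure keeps a list of its Axes
print(type(ax.xaxis).__name__)  # XAxis: each Axes owns an x- and a y-Axis
ax.set_xlim(0, 5)               # view limits are set through the Axes
print(ax.get_xlim())
```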

Explicit vs Implicit Interfaces

Explicit (object-oriented) interface (recommended for complex plots):

fig, ax = plt.subplots()
ax.plot(x, y)
ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_title('Title')
plt.show()

Implicit (pyplot) interface (quick and easy):

plt.plot(x, y)
plt.xlabel('X Label')
plt.ylabel('Y Label')
plt.title('Title')
plt.show()

The explicit interface is preferred for production code, complex layouts, and when you need fine control.
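The implicit calls act on pyplot's "current" Axes, which `plt.gca()` returns. The two interfaces can therefore be mixed on the same object. A small check with made-up data:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for running as a script
import matplotlib.pyplot as plt

plt.figure()
plt.plot([0, 1, 2], [0, 1, 4])  # implicit: draws on the current Axes
plt.title('Implicit')

ax = plt.gca()                  # the Axes pyplot created behind the scenes
print(len(ax.lines))            # 1
print(ax.get_title())           # Implicit

# The same Axes can be refined with the explicit API
ax.set_xlabel('x')
```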


Basic Plots

Line Plot

The most fundamental plot type for showing trends over continuous data.

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y = np.sin(x)

fig, ax = plt.subplots()
ax.plot(x, y)
ax.set_xlabel('x')
ax.set_ylabel('sin(x)')
ax.set_title('Sine Wave')
plt.show()

Multiple Lines

y2 = np.cos(x)

fig, ax = plt.subplots()
ax.plot(x, y, label='sin(x)')
ax.plot(x, y2, label='cos(x)')
ax.legend()
plt.show()

Scatter Plot

For visualizing relationships between two variables.

x = np.random.randn(100)
y = np.random.randn(100)
colors = np.random.randn(100)
sizes = np.random.randint(10, 100, 100)

fig, ax = plt.subplots()
scatter = ax.scatter(x, y, c=colors, s=sizes, alpha=0.5)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_title('Scatter Plot with Color and Size')
plt.show()
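The object returned by `ax.scatter` (kept in `scatter` above) can drive a colorbar that explains the color encoding. Below is a variant with seeded random data so the run is reproducible. The seed and colormap are arbitrary choices.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for running as a script
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)  # seeded example data
x, y = rng.standard_normal((2, 100))
colors = rng.standard_normal(100)
sizes = rng.integers(10, 100, 100)

fig, ax = plt.subplots()
scatter = ax.scatter(x, y, c=colors, s=sizes, alpha=0.5, cmap='viridis')
cbar = fig.colorbar(scatter, ax=ax)  # maps colors back to the `colors` values
cbar.set_label('color value')

print(len(fig.axes))  # 2: the plot Axes plus the colorbar Axes
```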

Bar Plot

For categorical data comparison.

categories = ['A', 'B', 'C', 'D']
values = [23, 45, 12, 67]

fig, ax = plt.subplots()
ax.bar(categories, values)
ax.set_xlabel('Categories')
ax.set_ylabel('Values')
ax.set_title('Bar Plot')
plt.show()

# Horizontal bar
fig, ax = plt.subplots()
ax.barh(categories, values)
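To print each bar's value on the chart, `ax.bar_label` (Matplotlib 3.4 or newer) labels every bar in the container returned by `ax.bar`. A sketch reusing the same data:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for running as a script
import matplotlib.pyplot as plt

categories = ['A', 'B', 'C', 'D']
values = [23, 45, 12, 67]

fig, ax = plt.subplots()
bars = ax.bar(categories, values)
labels = ax.bar_label(bars, fmt='%d', padding=3)  # one text label per bar
print([t.get_text() for t in labels])  # ['23', '45', '12', '67']
```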

Histogram

For visualizing data distributions.

data = np.random.randn(1000)

fig, ax = plt.subplots()
ax.hist(data, bins=30, edgecolor='black', alpha=0.7)
ax.set_xlabel('Value')
ax.set_ylabel('Frequency')
ax.set_title('Histogram')
plt.show()
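`ax.hist` also returns the computed counts and bin edges, which is handy for checking the binning. With `density=True` the bar heights are rescaled so the histogram's total area is 1. A sketch with seeded data:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for running as a script
import matplotlib.pyplot as plt
import numpy as np

data = np.random.default_rng(1).standard_normal(1000)

fig, ax = plt.subplots()
# hist returns the bar heights, the bin edges and the drawn patches
counts, edges, patches = ax.hist(data, bins=30, edgecolor='black', alpha=0.7)
print(counts.sum(), len(edges), len(patches))  # 1000.0 31 30

# density=True: height * bin width sums to 1
dens, edges2, _ = ax.hist(data, bins=30, density=True, histtype='step')
print(np.sum(dens * np.diff(edges2)))  # approximately 1.0
```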

Box Plot

For showing summary statistics and outliers.

data = np.random.randn(100, 3)

fig, ax = plt.subplots()
ax.boxplot(data)
ax.set_xlabel('Groups')
ax.set_ylabel('Values')
ax.set_title('Box Plot')
plt.show()

Pie Chart

For showing proportions of a whole.

sizes = [25, 35, 20, 20]
labels = ['A', 'B', 'C', 'D']
explode = [0, 0.1, 0, 0]

fig, ax = plt.subplots()
ax.pie(sizes, labels=labels, explode=explode, autopct='%1.1f%%', 
       startangle=90)
ax.axis('equal')
plt.show()

Advanced Customization

Colors and Styles

Available colors:

  • Named colors: 'red', 'blue', 'green', etc.
  • Hex codes: '#FF0000', '#00FF00'
  • RGB tuples: (0.1, 0.2, 0.3)
  • Grayscale: '0.5'

Color cycles and styles:

# Set color cycle
plt.rcParams['axes.prop_cycle'] = plt.cycler(color=['red', 'green', 'blue'])

# Use built-in styles
plt.style.available  # View available styles
plt.style.use('ggplot')  # Apply style
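`plt.style.use` changes the settings for the rest of the session. To apply a style only to a few figures, `plt.style.context` restores the previous rcParams when the `with` block ends:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for running as a script
import matplotlib.pyplot as plt

default_face = plt.rcParams['axes.facecolor']

# The style applies only inside the block
with plt.style.context('ggplot'):
    fig, ax = plt.subplots()
    ax.plot([0, 1, 2], [0, 1, 4])
    styled_face = plt.rcParams['axes.facecolor']

print(styled_face != default_face)                         # True
print(plt.rcParams['axes.facecolor'] == default_face)      # True: restored
```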

Markers and Linestyles

# Line styles
ax.plot(x, y, linestyle='--')  # 'solid', 'dashed', 'dotted', 'dashdot'

# Markers
ax.plot(x, y, marker='o')  # 'o', 's', '^', 'D', '*', '+', etc.

# Combined
ax.plot(x, y, 'ro--')  # red circles with dashed line
ax.plot(x, y, marker='o', color='red', linestyle='--', linewidth=2)

Figure and Axes Customization

# Figure size
fig, ax = plt.subplots(figsize=(10, 6))

# DPI (dots per inch) for resolution
fig, ax = plt.subplots(figsize=(8, 6), dpi=100)

# Titles and labels
ax.set_title('Main Title', fontsize=16, fontweight='bold')
ax.set_xlabel('X Axis', fontsize=12)
ax.set_ylabel('Y Axis', fontsize=12)

# Legends
ax.legend(loc='upper right', framealpha=0.5, fontsize=10)

# Grid
ax.grid(True, linestyle='--', alpha=0.7)

Ticks and Tick Labels

# Custom tick locations
ax.set_xticks([0, 2, 4, 6, 8, 10])
ax.set_yticks([-1, 0, 1])

# Custom tick labels
ax.set_xticklabels(['Zero', 'Two', 'Four', 'Six', 'Eight', 'Ten'])

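# Rotate tick labels (Axes method, so no pyplot state is involved)
ax.tick_params(axis='x', labelrotation=45)
ax.tick_params(axis='y', labelrotation=0)

# Scientific notation; scilimits=(0, 0) forces it for every value.
# Requires the default ScalarFormatter, so not on the x-axis with custom labels above
ax.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))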
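Recent Matplotlib versions can set tick positions and labels in one call, and a `FuncFormatter` can format tick values with any Python function. A sketch assuming Matplotlib 3.5 or newer (needed for the `labels=` argument). The percentage format is just an example.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for running as a script
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import FuncFormatter

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots()
ax.plot(x, np.sin(x) ** 2)

# Tick positions and labels in one call
ax.set_xticks([0, 5, 10], labels=['start', 'middle', 'end'])
ax.tick_params(axis='x', labelrotation=45)

# Show y tick values as percentages
ax.yaxis.set_major_formatter(FuncFormatter(lambda v, pos: f'{v:.0%}'))

print([t.get_text() for t in ax.get_xticklabels()])  # ['start', 'middle', 'end']
print(ax.yaxis.get_major_formatter()(0.25, 0))        # 25%
```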

Limits and Scaling

# Set limits
ax.set_xlim(0, 10)
ax.set_ylim(-1.5, 1.5)

# Auto-scaling
ax.autoscale()

# Log scales
ax.set_xscale('log')
ax.set_yscale('log')

# Symmetric log scale
ax.set_yscale('symlog')

Subplots and Layouts

Basic Subplots

# 2x2 grid of subplots
fig, axes = plt.subplots(2, 2, figsize=(10, 8))

# Access individual axes
ax1, ax2, ax3, ax4 = axes.flatten()

ax1.plot(x, y)
ax2.scatter(x, y2)
ax3.bar(categories, values)
ax4.hist(data)

plt.tight_layout()
plt.show()

Shared Axes

# Share x-axis
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)

# Share both axes
fig, axes = plt.subplots(2, 2, sharex=True, sharey=True)
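Shared axes stay linked after creation, so changing the limits on one changes the other. A quick check with sine and cosine data:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for running as a script
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)
ax1.plot(x, np.sin(x))
ax2.plot(x, np.cos(x))

# Zooming the top panel also zooms the bottom one
ax1.set_xlim(2, 4)
print(ax2.get_xlim())
```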

GridSpec for Complex Layouts

import matplotlib.gridspec as gridspec

fig = plt.figure(figsize=(10, 8))
gs = gridspec.GridSpec(3, 3, figure=fig)

ax1 = fig.add_subplot(gs[0, :])  # Top row, all columns
ax2 = fig.add_subplot(gs[1, :-1])  # Middle row, left 2 columns
ax3 = fig.add_subplot(gs[1:, -1])  # Middle and bottom rows, last column
ax4 = fig.add_subplot(gs[2, 0])  # Bottom row, left column
ax5 = fig.add_subplot(gs[2, 1])  # Bottom row, middle column

plt.tight_layout()
plt.show()
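The same layout can also be described as a grid of names with `plt.subplot_mosaic` (Matplotlib 3.3 or newer). A name repeated across cells makes that Axes span them. The names below are arbitrary labels chosen for this sketch.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for running as a script
import matplotlib.pyplot as plt
import numpy as np

# Same arrangement as the GridSpec example: each name is one Axes
fig, axd = plt.subplot_mosaic(
    [['top', 'top', 'top'],
     ['mid', 'mid', 'side'],
     ['left', 'center', 'side']],
    figsize=(10, 8),
)

x = np.linspace(0, 10, 100)
axd['top'].plot(x, np.sin(x))
axd['side'].plot(np.cos(x), x)
for name, ax in axd.items():
    ax.set_title(name)

print(sorted(axd))  # ['center', 'left', 'mid', 'side', 'top']
```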

Subplots with Different Sizes

fig = plt.figure(figsize=(10, 8))

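# Main plot (position is [left, bottom, width, height] in figure fractions)
ax_main = fig.add_axes([0.1, 0.1, 0.6, 0.65])

# Smaller axes to the right of and above the main plot
ax_right = fig.add_axes([0.75, 0.1, 0.2, 0.65])
ax_top = fig.add_axes([0.1, 0.8, 0.6, 0.15])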

ax_main.plot(x, y)
ax_right.scatter(x, y2)
ax_top.bar(categories, values)

Specialized Plot Types

Contour and Contourf

For 2D function visualization.

x = np.linspace(-3, 3, 50)
y = np.linspace(-3, 3, 50)
X, Y = np.meshgrid(x, y)
Z = np.exp(-X**2 - Y**2)

fig, ax = plt.subplots()
contour = ax.contourf(X, Y, Z, levels=20, cmap='viridis')
ax.contour(X, Y, Z, levels=5, colors='black', linewidths=0.5)
plt.colorbar(contour)
plt.show()

3D Plots

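import matplotlib.pyplot as plt
import numpy as np

# Grid data for the surface and wireframe (same function as the contour example)
x = np.linspace(-3, 3, 50)
y = np.linspace(-3, 3, 50)
X, Y = np.meshgrid(x, y)
Z = np.exp(-X**2 - Y**2)

fig = plt.figure(figsize=(10, 8))
# The '3d' projection is registered automatically in Matplotlib >= 3.2;
# older versions also need: from mpl_toolkits.mplot3d import Axes3D
ax = fig.add_subplot(111, projection='3d')

# Surface plot
ax.plot_surface(X, Y, Z, cmap='viridis', alpha=0.8)

# Scatter plot of points along the diagonal x = y
z = np.exp(-x**2 - y**2)
ax.scatter3D(x, y, z, c=z, cmap='viridis')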

# Wireframe
ax.plot_wireframe(X, Y, Z, color='black', alpha=0.3)

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
plt.show()

Heatmap

data = np.random.rand(10, 12)

fig, ax = plt.subplots()
im = ax.imshow(data, cmap='hot', interpolation='nearest')
ax.set_xticks(np.arange(12))
ax.set_yticks(np.arange(10))
plt.colorbar(im)
plt.show()
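Heatmaps are often easier to read with the value printed in each cell. `imshow` places cell `(row i, column j)` at data coordinates `x = j, y = i`, so a double loop over the array can place the labels. Random data stands in for real values here.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for running as a script
import matplotlib.pyplot as plt
import numpy as np

data = np.random.default_rng(2).random((4, 5))

fig, ax = plt.subplots()
im = ax.imshow(data, cmap='viridis')
fig.colorbar(im, ax=ax)

# Write each value at its cell center
for i in range(data.shape[0]):
    for j in range(data.shape[1]):
        ax.text(j, i, f'{data[i, j]:.2f}', ha='center', va='center', color='w')

print(len(ax.texts))  # 20
```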

Violin Plot

data = [np.random.randn(100) for _ in range(4)]

fig, ax = plt.subplots()
ax.violinplot(data, showmeans=True, showmedians=True)
ax.set_xticks([1, 2, 3, 4])
ax.set_xticklabels(['A', 'B', 'C', 'D'])
plt.show()

Step Plot

x = np.arange(5)
y = np.array([1, 3, 2, 5, 4])

fig, ax = plt.subplots()
ax.step(x, y, where='mid')
ax.plot(x, y, 'o--', alpha=0.5)
plt.show()

Streamplot

For visualizing vector fields.

x = np.linspace(-3, 3, 30)
y = np.linspace(-3, 3, 30)
X, Y = np.meshgrid(x, y)
U = -1 - X**2 + Y
V = 1 + X - Y**2

fig, ax = plt.subplots()
ax.streamplot(X, Y, U, V, density=1, linewidth=1, arrowsize=1)
plt.show()

Annotations and Text

Adding Text

x = np.linspace(0, 10, 100)
y = np.sin(x)

fig, ax = plt.subplots()
ax.plot(x, y)

# Basic text (positioned in data coordinates)
ax.text(5, 0, 'Important Point', fontsize=12)

# Text with box
ax.text(5, -0.5, 'Boxed Point',
        bbox=dict(facecolor='yellow', alpha=0.5),
        fontsize=12)

# Math expressions
ax.text(5, 0.5, r'$\sin(x)$', fontsize=14)

# Title and labels with mathematical notation
ax.set_title(r'$\sin(x)$ and $\cos(x)$')
ax.set_xlabel(r'$x$ axis')
ax.set_ylabel(r'$y$ axis')

Annotations

# Point annotation
ax.annotate('Local Max', xy=(np.pi/2, 1), xytext=(np.pi/2, 1.2),
            arrowprops=dict(facecolor='black', shrink=0.05))

# Arrow annotation
ax.annotate('', xy=(0, 0), xytext=(1, 1),
            arrowprops=dict(arrowstyle='->'))

Legends and Colorbars

Advanced Legend Customization

# Multiple columns in legend
ax.legend(ncol=2, loc='upper center', bbox_to_anchor=(0.5, 1.15))

# Legend outside plot
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

# Custom legend entries
from matplotlib.patches import Patch
handles = [Patch(color='red', label='Red'), 
           Patch(color='blue', label='Blue')]
ax.legend(handles=handles)

# Remove an existing legend
ax.get_legend().remove()

Colorbar Customization

fig, ax = plt.subplots()
x, y, z = np.random.randn(3, 100)  # example data
sc = ax.scatter(x, y, c=z, cmap='viridis')
cbar = fig.colorbar(sc, ax=ax)

# Customize colorbar
cbar.set_label('Value', fontsize=12)
cbar.ax.tick_params(labelsize=10)
cbar.set_ticks([-2, 0, 2])

# Remove colorbar
cbar.remove()

Working with Text and Fonts

Font Properties

import matplotlib as mpl

# Set font globally
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.size'] = 12

# Custom font properties
from matplotlib.font_manager import FontProperties
font = FontProperties(family='serif', size=14, weight='bold')
ax.set_title('Title', fontproperties=font)

# Mathematical text
ax.text(0.5, 0.5, r'$\frac{1}{2} \pi \sigma$', fontsize=16)

Saving Figures

Basic Saving

plt.savefig('plot.png')
plt.savefig('plot.pdf')
plt.savefig('plot.svg')

Customizing Saved Figures

# High quality
plt.savefig('plot.png', dpi=300, bbox_inches='tight')

# Transparent background
plt.savefig('plot.png', transparent=True)

# No white space around figure
plt.savefig('plot.png', bbox_inches='tight', pad_inches=0)

# Opaque white background, no edge color, trimmed bounding box
plt.savefig('plot.png', dpi=300, bbox_inches='tight',
            facecolor='white', edgecolor='none', format='png')

# Multiple formats
plt.savefig('plot.pdf')
plt.savefig('plot.png')
plt.savefig('plot.eps')
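`savefig` also accepts a file-like object, which is useful for web responses or for embedding a figure without a temporary file. Passing `format=` makes the output type explicit because there is no file extension to infer it from. Without it, `rcParams['savefig.format']` is used.

```python
import io

import matplotlib
matplotlib.use("Agg")  # headless backend for running as a script
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 1, 2], [0, 1, 4])

buf = io.BytesIO()
fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
plt.close(fig)  # free the figure once it is saved

png_bytes = buf.getvalue()
print(png_bytes[:8] == b'\x89PNG\r\n\x1a\n')  # True: PNG file signature
```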

Animations

Basic Animation

from matplotlib.animation import FuncAnimation

fig, ax = plt.subplots()
x = np.linspace(0, 2*np.pi, 100)
line, = ax.plot(x, np.sin(x))

def update(frame):
    line.set_ydata(np.sin(x + frame * 0.1))
    return line,

ani = FuncAnimation(fig, update, frames=100, interval=50)

# Save before plt.show(): in a script, show() blocks until the window is closed
ani.save('animation.gif', writer='pillow', fps=20)
plt.show()

Saving Animations

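from matplotlib.animation import FFMpegWriter

# MP4 output needs the ffmpeg executable installed and on PATH
writer = FFMpegWriter(fps=30)
ani.save('animation.mp4', writer=writer)

# GIF output uses Pillow, which is installed as a Matplotlib dependency
ani.save('animation.gif', writer='pillow', fps=30)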
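The sketch below is a self-contained version that uses blitting and writes a GIF with `PillowWriter`. With `blit=True`, `init_func` draws a clean background frame, and both callbacks return the artists that change. The output goes to a temporary directory, and the file name is arbitrary.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # headless backend for running as a script
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation, PillowWriter

x = np.linspace(0, 2 * np.pi, 100)
fig, ax = plt.subplots()
(line,) = ax.plot(x, np.sin(x))

def init():
    # Blank the line so the cached background has no data on it
    line.set_ydata(np.full_like(x, np.nan))
    return (line,)

def update(frame):
    line.set_ydata(np.sin(x + frame * 0.1))
    return (line,)

ani = FuncAnimation(fig, update, frames=20, init_func=init, blit=True, interval=50)

out_path = os.path.join(tempfile.mkdtemp(), 'wave.gif')
ani.save(out_path, writer=PillowWriter(fps=20))
plt.close(fig)
```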

Working with Images

Displaying Images

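import matplotlib.pyplot as plt

# Load and display image (plt.imread reads PNG natively and formats such as JPEG via Pillow)
img = plt.imread('image.jpg')
fig, ax = plt.subplots()
ax.imshow(img)
ax.axis('off')
plt.show()

# Display with colorbar: cmap only applies to single-channel (2D) data,
# so convert the RGB image to grayscale first
gray = img[..., :3].mean(axis=2)
fig, ax = plt.subplots()
im = ax.imshow(gray, cmap='gray')
plt.colorbar(im)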

Image Manipulation

# Extract channels
red = img[:, :, 0]
green = img[:, :, 1]
blue = img[:, :, 2]

# Display subplots with images
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].imshow(red, cmap='Reds')
axes[1].imshow(green, cmap='Greens')
axes[2].imshow(blue, cmap='Blues')

Advanced Techniques

Twin Axes

For plotting variables with different scales.

fig, ax1 = plt.subplots()
ax1.plot(x, y, color='blue', label='sin(x)')
ax1.set_xlabel('x')
ax1.set_ylabel('sin(x)', color='blue')

ax2 = ax1.twinx()
ax2.plot(x, y2, color='red', label='cos(x)')
ax2.set_ylabel('cos(x)', color='red')

plt.show()

Inset Axes

fig, ax = plt.subplots()
ax.plot(x, y)

# Create inset
inset = ax.inset_axes([0.6, 0.6, 0.3, 0.3])
inset.plot(x, y2)
inset.set_title('Zoomed In')

plt.show()

Broken Axis

fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True, figsize=(10, 6))

ax1.plot(x[:50], y[:50])
ax1.set_xlim(0, 5)

ax2.plot(x[50:], y[50:])
ax2.set_xlim(5, 10)

# Hide the spines and ticks where the two panels meet
ax1.spines.right.set_visible(False)
ax2.spines.left.set_visible(False)
ax1.tick_params(right=False)
ax2.tick_params(left=False)
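The code above only hides the adjoining spines. The usual diagonal "break" marks can be drawn as custom markers placed in axes coordinates, following the approach of Matplotlib's broken-axis gallery example. The marker slope and size are arbitrary choices in this sketch.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for running as a script
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y = np.sin(x)

fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True, figsize=(10, 6))
fig.subplots_adjust(wspace=0.05)  # keep the two panels close together
ax1.plot(x[:50], y[:50])
ax1.set_xlim(0, 5)
ax2.plot(x[50:], y[50:])
ax2.set_xlim(5, 10)

ax1.spines.right.set_visible(False)
ax2.spines.left.set_visible(False)
ax1.tick_params(right=False)
ax2.tick_params(left=False)

# Slanted marks at the top and bottom of the hidden spines; transAxes puts
# them at axes fractions (0 to 1) and clip_on=False lets them sit on the edge
d = 0.5  # horizontal extent of the slanted marker
kw = dict(marker=[(-d, -1), (d, 1)], markersize=12, linestyle='none',
          color='k', mec='k', mew=1, clip_on=False)
ax1.plot([1, 1], [0, 1], transform=ax1.transAxes, **kw)
ax2.plot([0, 0], [0, 1], transform=ax2.transAxes, **kw)
```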

Logarithmic Scales

x = np.logspace(0, 3, 100)
y = x**2

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
ax1.plot(x, y)
ax1.set_title('Linear Scale')

ax2.plot(x, y)
ax2.set_xscale('log')
ax2.set_yscale('log')
ax2.set_title('Log-Log Scale')

Customizing with rcParams

Global Settings

import matplotlib as mpl

# Set global parameters
plt.rcParams['figure.figsize'] = [10, 6]
plt.rcParams['figure.dpi'] = 100
plt.rcParams['font.size'] = 12
plt.rcParams['lines.linewidth'] = 2
plt.rcParams['axes.grid'] = True
plt.rcParams['grid.alpha'] = 0.3

# Reset to Matplotlib's defaults
mpl.rcdefaults()

Custom Style Sheet

# Create custom style
style = {
    'figure.figsize': (10, 6),
    'font.size': 12,
    'lines.linewidth': 2,
    'axes.grid': True,
    'axes.labelsize': 14,
    'axes.titlesize': 16,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12
}
plt.rcParams.update(style)
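A style sheet proper is a plain-text `.mplstyle` file of `key: value` lines. `plt.style.use` and `plt.style.context` accept a path to such a file. In the sketch below, the file name and values are illustrative, and the file is written to a temporary directory.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # headless backend for running as a script
import matplotlib.pyplot as plt

style_text = """
figure.figsize: 10, 6
font.size: 12
lines.linewidth: 2
axes.grid: True
axes.labelsize: 14
axes.titlesize: 16
"""
path = os.path.join(tempfile.mkdtemp(), 'mystyle.mplstyle')  # hypothetical file name
with open(path, 'w') as f:
    f.write(style_text)

# Apply it temporarily by path; plt.style.use(path) would apply it globally
with plt.style.context(path):
    lw_inside = plt.rcParams['lines.linewidth']
    fig, ax = plt.subplots()
    ax.plot([0, 1], [0, 1])

print(lw_inside)  # 2.0
```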

Event Handling and Interactivity

Mouse Events

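def on_click(event):
    # Ignore clicks outside the axes, where xdata/ydata are None
    if event.inaxes is not ax:
        return
    print(f'Click at ({event.xdata:.2f}, {event.ydata:.2f})')
    ax.plot(event.xdata, event.ydata, 'ro')
    fig.canvas.draw_idle()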

fig, ax = plt.subplots()
ax.plot(x, y)
fig.canvas.mpl_connect('button_press_event', on_click)
plt.show()

Keyboard Events

def on_key(event):
    if event.key == 'r':
        # Reset view
        ax.set_xlim(0, 10)
        ax.set_ylim(-1.5, 1.5)
        fig.canvas.draw()

fig.canvas.mpl_connect('key_press_event', on_key)

Working with Dates

Date Formatting

import matplotlib.dates as mdates

dates = np.array(['2023-01-01', '2023-02-01', '2023-03-01'], dtype='datetime64')
values = [10, 15, 20]

fig, ax = plt.subplots()
ax.plot(dates, values)

# Format date ticks: one tick per month, labelled as YYYY-MM-DD
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
ax.xaxis.set_major_locator(mdates.MonthLocator())
ax.tick_params(axis='x', labelrotation=45)

# Alternative: let Matplotlib choose tick spacing and format
# (this replaces the locator and formatter set above)
ax.xaxis.set_major_locator(mdates.AutoDateLocator())
ax.xaxis.set_major_formatter(mdates.AutoDateFormatter(ax.xaxis.get_major_locator()))
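Another option is `ConciseDateFormatter` (Matplotlib 3.1 or newer). It shortens labels by moving the shared part, such as the year, into an offset, so rotation is usually unnecessary. A sketch with six months of synthetic daily data:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for running as a script
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np

dates = np.arange('2023-01-01', '2023-07-01', dtype='datetime64[D]')
values = np.cumsum(np.random.default_rng(3).standard_normal(len(dates)))

fig, ax = plt.subplots()
ax.plot(dates, values)

# Pair the formatter with the locator it should describe
locator = mdates.AutoDateLocator()
ax.xaxis.set_major_locator(locator)
ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(locator))
fig.canvas.draw()
```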

Integration with Other Libraries

With Seaborn

import seaborn as sns

# Use seaborn for plotting
sns.set_style('darkgrid')
fig, ax = plt.subplots()
sns.lineplot(x=x, y=y, ax=ax)

# Matplotlib for fine-tuning
ax.set_title('Plot with Seaborn Style')

With Pandas

import pandas as pd

df = pd.DataFrame({'x': x, 'y': y})
fig, ax = plt.subplots()
df.plot(x='x', y='y', ax=ax, style='o-')

With Cartopy for Maps

import cartopy.crs as ccrs

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection=ccrs.PlateCarree())
ax.coastlines()
plt.show()

Performance Optimization

Optimizing for Speed

# Reduce resolution for interactive plots
plt.rcParams['figure.dpi'] = 50

# Use blitting for animations
ani = FuncAnimation(fig, update, blit=True)

# Disable anti-aliasing for faster rendering
ax.plot(x, y, antialiased=False)

# Rasterize dense artists so vector outputs (PDF/SVG) stay small
ax.plot(x, y, rasterized=True)

Handling Large Datasets

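# Downsample data (max(1, ...) keeps the step valid for short arrays)
step = max(1, len(x) // 1000)
ax.plot(x[::step], y[::step])

# Use a LineCollection to draw many segments as a single artist
from matplotlib.collections import LineCollection
points = np.column_stack([x, y]).reshape(-1, 1, 2)
segments = np.concatenate([points[:-1], points[1:]], axis=1)  # shape (N-1, 2, 2)
lc = LineCollection(segments)
ax.add_collection(lc)
ax.autoscale_view()  # make sure the view limits include the collection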
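A `LineCollection` pays off most when there are many separate lines. A single artist replaces hundreds of `Line2D` objects. The sketch below draws 500 synthetic random walks. The counts, widths and alpha are arbitrary.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for running as a script
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import LineCollection

rng = np.random.default_rng(4)
n_lines, n_points = 500, 200
x = np.linspace(0, 10, n_points)
ys = np.cumsum(rng.standard_normal((n_lines, n_points)), axis=1)  # random walks

# One (n_points, 2) array of (x, y) vertices per line: shape (500, 200, 2)
segments = np.stack([np.broadcast_to(x, ys.shape), ys], axis=-1)

fig, ax = plt.subplots()
lc = LineCollection(segments, linewidths=0.5, alpha=0.3)
ax.add_collection(lc)  # one artist instead of 500 Line2D objects
ax.autoscale_view()    # make sure the view limits include the collection

print(len(lc.get_segments()), len(ax.lines))  # 500 0
```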

Troubleshooting

Common Issues

Missing fonts or LaTeX:

# Use built-in math rendering
plt.rcParams['text.usetex'] = False

# Install LaTeX if needed
# On Ubuntu: sudo apt-get install texlive texlive-latex-extra

Memory issues with animations:

# Clear figures
plt.close('all')

# Render frames one figure at a time, closing each after saving
for i in range(10):
    fig = plt.figure()
    # ... plotting code ...
    plt.savefig(f'frame_{i:03d}.png')
    plt.close(fig)

Tight layout not working:

# Use constrained layout instead (do not also call plt.tight_layout())
fig, ax = plt.subplots(constrained_layout=True)

# Or set manually
plt.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1)

Best Practices

Code Organization

# Create reusable plotting functions
def create_line_plot(x, y, title='', xlabel='', ylabel='', **kwargs):
    fig, ax = plt.subplots()
    ax.plot(x, y, **kwargs)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True)
    return fig, ax

# Use consistent styling
plt.style.use('seaborn-v0_8-whitegrid')
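A variant of the reusable-function idea, in the spirit of the pattern suggested in Matplotlib's own quick-start guide, is for the helper to take an existing `ax` instead of creating a figure. The caller then controls the layout, and the same helper works in subplot grids. `plot_series` is a hypothetical name used only for this sketch.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for running as a script
import matplotlib.pyplot as plt
import numpy as np

def plot_series(ax, x, y, **line_kwargs):
    """Hypothetical helper: draw onto the given Axes and return the line."""
    (line,) = ax.plot(x, y, **line_kwargs)
    ax.grid(True)
    return line

x = np.linspace(0, 10, 100)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
plot_series(ax1, x, np.sin(x), color='C0', label='sin(x)')
plot_series(ax2, x, np.cos(x), color='C1', linestyle='--', label='cos(x)')
for ax in (ax1, ax2):
    ax.legend()
```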

Accessibility

# Use colorblind-friendly palettes (Okabe-Ito colors) as the color cycle
colors = ['#0072B2', '#D55E00', '#009E73', '#CC79A7', '#F0E442']
ax.set_prop_cycle(color=colors)

# Add labels and legends
ax.legend()
ax.set_xlabel('X Axis')
ax.set_ylabel('Y Axis')
ax.set_title('Descriptive Title')

# Use patterns for black-and-white prints
ax.bar(categories, values, hatch='//')

Production Considerations

# Consistent size and resolution
fig, ax = plt.subplots(figsize=(8, 6), dpi=300)

# High-quality output
plt.savefig('final_plot.png', dpi=300, bbox_inches='tight',
            facecolor='white', edgecolor='none')

# Use vector formats when possible
plt.savefig('final_plot.pdf', bbox_inches='tight')