Lesson 14 - Creating Subplots
Multi-Panel Figures
So far you have created single plots, with one chart per figure. Now you will learn to create subplots: multiple charts arranged in a grid within a single figure.
By the end of this lesson, you will be able to:
- Use plt.subplots() to create figure and axes objects
- Arrange plots in grids (rows × columns)
- Access individual subplot axes
- Share x or y axes across subplots
- Adjust spacing between subplots
- Create professional multi-panel figures
- Build comparison dashboards
Subplots are essential for comparing data side-by-side and creating comprehensive visualizations.
Understanding Figure and Axes
The Object-Oriented Approach
Before (simple plotting):
import matplotlib.pyplot as plt
plt.plot([1, 2, 3], [1, 4, 9])
plt.show()
Now (object-oriented):
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot([1, 2, 3], [1, 4, 9])
plt.show()
What changed:
- fig → the entire figure (canvas)
- ax → the axes (the plotting area)
- Plot on ax instead of using plt
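The two styles are connected. pyplot functions such as plt.plot() draw on the "current" axes, which plt.gca() returns, and right after plt.subplots() that is the axes you just created. The sketch below illustrates this. The Agg backend is an assumption chosen only so the example runs without a display window; in a notebook you would omit that line.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# ax belongs to fig and is the figure's only axes
print(ax.figure is fig)    # True
print(fig.axes == [ax])    # True

# pyplot functions act on the "current" axes, which is the one just created
print(plt.gca() is ax)     # True

# So here plt.plot(...) and ax.plot(...) draw in the same plotting area
plt.plot([1, 2, 3], [1, 4, 9])
ax.plot([1, 2, 3], [1, 8, 27])
print(len(ax.lines))       # 2
```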
Concept diagram:
┌─────────────────────────── Figure ───────────────────────────┐
│                                                              │
│                     Title of the Figure                      │
│                                                              │
│  ┌───────────────────────── Axes ─────────────────────────┐  │
│  │                                                        │  │
│  │   Y-axis                                               │  │
│  │     ^                                                  │  │
│  │     │    Plot area (data visualization)                │  │
│  │     │                                                  │  │
│  │     └──────────────> X-axis                            │  │
│  │                                                        │  │
│  └────────────────────────────────────────────────────────┘  │
│                                                              │
└──────────────────────────────────────────────────────────────┘
Creating Simple Subplots
1 Row, 2 Columns
import pandas as pd
import matplotlib.pyplot as plt
bikes = pd.read_csv('day.csv')
bikes['dteday'] = pd.to_datetime(bikes['dteday'])
# Create 1 row, 2 columns of subplots
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# Left plot: Total rentals
axes[0].plot(bikes['dteday'], bikes['cnt'], color='steelblue', linewidth=1)
axes[0].set_xlabel('Date')
axes[0].set_ylabel('Total Rentals')
axes[0].set_title('Total Daily Rentals')
axes[0].grid(True, alpha=0.3)
# Right plot: Casual vs Registered
axes[1].plot(bikes['dteday'], bikes['casual'], label='Casual', linewidth=1)
axes[1].plot(bikes['dteday'], bikes['registered'], label='Registered', linewidth=1)
axes[1].set_xlabel('Date')
axes[1].set_ylabel('Rentals')
axes[1].set_title('Casual vs Registered Users')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
Key points:
- plt.subplots(1, 2) → 1 row, 2 columns
- axes[0] → left plot
- axes[1] → right plot
- plt.tight_layout() → adjusts spacing automatically
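The type of object that plt.subplots() returns as `axes` depends on the grid shape. The minimal sketch below checks this with the default `squeeze=True`, and also with `squeeze=False`, which always returns a 2D array. The Agg backend is assumed only so the code runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for the sketch
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

# A single subplot: one Axes object, not an array
fig1, ax = plt.subplots()
print(isinstance(ax, Axes))   # True

# One row or one column: a 1D NumPy array, indexed as axes[i]
fig2, axes_row = plt.subplots(1, 2)
print(axes_row.shape)         # (2,)

fig3, axes_col = plt.subplots(2, 1)
print(axes_col.shape)         # (2,)

# Several rows and columns: a 2D array, indexed as axes[row, col]
fig4, axes_grid = plt.subplots(2, 2)
print(axes_grid.shape)        # (2, 2)

# squeeze=False keeps a 2D array even for a single row
fig5, axes_2d = plt.subplots(1, 2, squeeze=False)
print(axes_2d.shape)          # (1, 2)
```

This is why the 1×2 example above uses axes[0] while a 2×2 grid needs axes[row, col].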
2 Rows, 1 Column
import pandas as pd
import matplotlib.pyplot as plt
bikes = pd.read_csv('day.csv')
bikes['dteday'] = pd.to_datetime(bikes['dteday'])
# Create 2 rows, 1 column
fig, axes = plt.subplots(2, 1, figsize=(12, 8))
# Top plot: Rentals
axes[0].plot(bikes['dteday'], bikes['cnt'], color='darkgreen', linewidth=1)
axes[0].set_ylabel('Total Rentals')
axes[0].set_title('Daily Bike Rentals')
axes[0].grid(True, alpha=0.3)
# Bottom plot: Temperature
axes[1].plot(bikes['dteday'], bikes['temp'], color='red', linewidth=1)
axes[1].set_xlabel('Date')
axes[1].set_ylabel('Normalized Temperature')
axes[1].set_title('Daily Temperature')
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
Observation: Vertical stacking works well for comparing time series that share the same x-axis.
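Instead of indexing axes[0] and axes[1], you can unpack a 1D axes array directly into descriptively named variables. In the sketch below, sine and cosine placeholder data stand in for the bike-sharing columns, and the Agg backend lets it run headless. Both are assumptions made for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np

# Placeholder data (not from day.csv)
x = np.linspace(0, 10, 200)

# Unpack the two stacked axes into named variables
fig, (ax_top, ax_bottom) = plt.subplots(2, 1, figsize=(12, 8))

ax_top.plot(x, np.sin(x), color='darkgreen', linewidth=1)
ax_top.set_ylabel('Signal A')
ax_top.set_title('Top panel')

ax_bottom.plot(x, np.cos(x), color='red', linewidth=1)
ax_bottom.set_xlabel('x')
ax_bottom.set_ylabel('Signal B')
ax_bottom.set_title('Bottom panel')

fig.tight_layout()
```

Unpacking raises a ValueError if the number of names does not match the number of subplots, which catches layout mistakes early.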
2x2 Grid Layout
Four-Panel Dashboard
import pandas as pd
import matplotlib.pyplot as plt
bikes = pd.read_csv('day.csv')
bikes['dteday'] = pd.to_datetime(bikes['dteday'])
# Create 2x2 grid
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
# Top-left: Rentals over time
axes[0, 0].plot(bikes['dteday'], bikes['cnt'], color='steelblue', linewidth=1)
axes[0, 0].set_ylabel('Total Rentals')
axes[0, 0].set_title('Daily Rentals Over Time')
axes[0, 0].grid(True, alpha=0.3)
# Top-right: Rental distribution
axes[0, 1].hist(bikes['cnt'], bins=30, color='coral', edgecolor='black', alpha=0.7)
axes[0, 1].set_xlabel('Daily Rentals')
axes[0, 1].set_ylabel('Frequency')
axes[0, 1].set_title('Rental Distribution')
axes[0, 1].grid(True, alpha=0.3, axis='y')
# Bottom-left: Temperature vs Rentals
axes[1, 0].scatter(bikes['temp'], bikes['cnt'], alpha=0.5, s=30, color='green')
axes[1, 0].set_xlabel('Normalized Temperature')
axes[1, 0].set_ylabel('Total Rentals')
axes[1, 0].set_title('Temperature vs Rentals')
axes[1, 0].grid(True, alpha=0.3)
# Bottom-right: Seasonal averages
season_avg = bikes.groupby('season')['cnt'].mean()
season_names = ['Spring', 'Summer', 'Fall', 'Winter']
axes[1, 1].bar(season_names, season_avg.values, color='orange', edgecolor='black')
axes[1, 1].set_ylabel('Average Daily Rentals')
axes[1, 1].set_title('Average Rentals by Season')
axes[1, 1].grid(True, alpha=0.3, axis='y')
plt.tight_layout()
plt.show()
Accessing a 2D grid with axes[row, col]:
- axes[0, 0] → top-left
- axes[0, 1] → top-right
- axes[1, 0] → bottom-left
- axes[1, 1] → bottom-right
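Because a 2D `axes` is a NumPy array, standard array indexing also applies. The sketch below shows that axes[row, col] and axes[row][col] return the same object, selects a whole column by slicing, and unpacks the grid into named panels. The Agg backend is an assumption for headless running.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for the sketch
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2)

# axes[row, col] and axes[row][col] refer to the same Axes object
print(axes[0, 1] is axes[0][1])        # True

# A whole row or column can be selected with slicing
top_row = axes[0, :]    # top-left, top-right
left_col = axes[:, 0]   # top-left, bottom-left
print(top_row.shape, left_col.shape)   # (2,) (2,)

# Nested unpacking gives each panel its own name
(ax_tl, ax_tr), (ax_bl, ax_br) = axes
print(ax_br is axes[1, 1])             # True

# Example use of a slice: put y-labels only on the left column
for ax in left_col:
    ax.set_ylabel('Total Rentals')
```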
Sharing Axes
Shared X-Axis
import pandas as pd
import matplotlib.pyplot as plt
bikes_hour = pd.read_csv('hour.csv')
# Calculate hourly averages
hourly_casual = bikes_hour.groupby('hr')['casual'].mean()
hourly_registered = bikes_hour.groupby('hr')['registered'].mean()
# Create subplots with shared x-axis
fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)
# Top: Casual users
axes[0].plot(hourly_casual.index, hourly_casual.values,
color='skyblue', linewidth=2, marker='o', markersize=4)
axes[0].set_ylabel('Average Casual Rentals')
axes[0].set_title('Hourly Patterns: Casual Users')
axes[0].grid(True, alpha=0.3)
# Bottom: Registered users
axes[1].plot(hourly_registered.index, hourly_registered.values,
color='coral', linewidth=2, marker='o', markersize=4)
axes[1].set_xlabel('Hour of Day')
axes[1].set_ylabel('Average Registered Rentals')
axes[1].set_title('Hourly Patterns: Registered Users')
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
Benefits of sharex=True:
- Both plots use the same x-axis range
- X-axis tick labels appear only on the bottom plot
- Plots stay aligned for easy comparison
- Zooming or panning one plot (in an interactive window) updates both
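With `sharex=True`, the x-axis view limits of the subplots are linked. The minimal sketch below shows that setting the limits on one axes changes them on the other. The data are placeholder values, and the Agg backend is an assumption for headless running.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for the sketch
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 1, sharex=True)
axes[0].plot(range(24), range(24))                        # placeholder data
axes[1].plot(range(24), [h % 12 for h in range(24)])      # placeholder data

# Restrict the top panel to hours 6 to 10...
axes[0].set_xlim(6, 10)

# ...and the bottom panel follows, because the x-axis is shared
print(axes[1].get_xlim() == (6.0, 10.0))   # True
```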
Shared Y-Axis
import pandas as pd
import matplotlib.pyplot as plt
bikes = pd.read_csv('day.csv')
# Non-working days (weekends and holidays) vs working days
weekend_rentals = bikes[bikes['workingday'] == 0]['cnt']
workday_rentals = bikes[bikes['workingday'] == 1]['cnt']
# Create subplots with shared y-axis
fig, axes = plt.subplots(1, 2, figsize=(14, 5), sharey=True)
# Left: Non-working days (weekends and holidays)
axes[0].hist(weekend_rentals, bins=30, color='lightblue', edgecolor='black', alpha=0.7)
axes[0].set_xlabel('Daily Rentals')
axes[0].set_ylabel('Frequency')
axes[0].set_title('Rental Distribution: Weekends & Holidays')
axes[0].grid(True, alpha=0.3, axis='y')
# Right: Workday distribution
axes[1].hist(workday_rentals, bins=30, color='lightgreen', edgecolor='black', alpha=0.7)
axes[1].set_xlabel('Daily Rentals')
axes[1].set_title('Rental Distribution: Workdays')
axes[1].grid(True, alpha=0.3, axis='y')
plt.tight_layout()
plt.show()
Benefits of sharey=True:
- Same y-axis scale for both plots
- Y-axis tick labels appear only on the left plot
- Direct comparison of distributions
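Besides True and False, `sharex` and `sharey` also accept 'row', 'col', 'all', and 'none'. The sketch below uses `sharey='row'` on a 2×2 grid: panels in the same row share a y-axis, while different rows keep independent scales. The Agg backend is an assumption for headless running.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for the sketch
import matplotlib.pyplot as plt

# sharey='row': panels in the same row share a y-axis; rows stay independent
fig, axes = plt.subplots(2, 2, sharey='row')

axes[0, 0].set_ylim(0, 100)

print(axes[0, 1].get_ylim() == (0.0, 100.0))   # True: same row
print(axes[1, 0].get_ylim() == (0.0, 100.0))   # False: different row
```

This is useful when each row shows a different quantity (for example, rentals on top and temperature below) that should only be compared within its own row.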
Flattening Axes for Loops
Using axes.flat
import pandas as pd
import matplotlib.pyplot as plt
bikes = pd.read_csv('day.csv')
# Variables to plot
variables = ['temp', 'atemp', 'hum', 'windspeed']
titles = ['Temperature', 'Feels-Like Temp', 'Humidity', 'Wind Speed']
colors = ['red', 'orange', 'blue', 'green']
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
# Flatten 2D array to 1D for easy iteration
for ax, var, title, color in zip(axes.flat, variables, titles, colors):
ax.scatter(bikes[var], bikes['cnt'], alpha=0.5, s=30, color=color)
ax.set_xlabel(title)
ax.set_ylabel('Daily Rentals')
ax.set_title(f'{title} vs Rentals')
ax.grid(True, alpha=0.3)
# Add correlation
corr = bikes[var].corr(bikes['cnt'])
ax.text(0.05, 0.95, f'r = {corr:.3f}',
transform=ax.transAxes, fontsize=11, verticalalignment='top',
bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
plt.tight_layout()
plt.show()
axes.flat advantage: it lets you iterate over a 2D grid of axes as if it were a 1D list.
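axes.flat visits panels in row-major order (left to right, then top to bottom). When there are fewer variables than grid cells, zip() stops at the shorter sequence and leaves empty panels. One way to handle this is to hide the leftover axes, as the sketch below does. The random placeholder data, the fifth variable name, and the Agg backend are assumptions made for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
variables = ['temp', 'atemp', 'hum', 'windspeed', 'casual']  # 5 panels
fig, axes = plt.subplots(2, 3, figsize=(15, 8))               # 6 slots

# axes.flat visits [0, 0], [0, 1], [0, 2], [1, 0], ... in that order
for ax, var in zip(axes.flat, variables):
    ax.scatter(rng.random(50), rng.random(50), s=10)  # placeholder points
    ax.set_title(var)

# Hide any axes that did not receive a plot
for ax in axes.flat[len(variables):]:
    ax.set_visible(False)

fig.tight_layout()
```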
Figure-Level Titles
Super Title
import pandas as pd
import matplotlib.pyplot as plt
bikes = pd.read_csv('day.csv')
fig, axes = plt.subplots(1, 3, figsize=(16, 5))
# Plot by season
for season, ax in zip([1, 2, 3], axes):
season_data = bikes[bikes['season'] == season]
ax.scatter(season_data['temp'], season_data['cnt'], alpha=0.6, s=30)
ax.set_xlabel('Temperature')
ax.set_ylabel('Daily Rentals')
season_name = ['Spring', 'Summer', 'Fall'][season - 1]
ax.set_title(season_name)
ax.grid(True, alpha=0.3)
# Add figure-level title
fig.suptitle('Temperature vs Rentals by Season', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()
fig.suptitle() adds a single title above all subplots.
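When subplots share an axis, repeating the same axis label on every panel is redundant. Matplotlib 3.4 and later also provide fig.supxlabel() and fig.supylabel() for figure-level axis labels. Like fig.suptitle(), each returns a Text object. The sketch assumes Matplotlib 3.4 or newer, uses placeholder points, and selects the Agg backend for headless running.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for the sketch
import matplotlib.pyplot as plt

seasons = ['Spring', 'Summer', 'Fall']
fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)

for ax, name in zip(axes, seasons):
    ax.scatter([0.2, 0.5, 0.8], [2000, 4000, 6000], alpha=0.6, s=30)  # placeholder points
    ax.set_title(name)  # per-panel title

# One title and one label per axis for the whole figure (Matplotlib >= 3.4)
title = fig.suptitle('Temperature vs Rentals by Season', fontsize=16, fontweight='bold')
xlabel = fig.supxlabel('Normalized Temperature')
ylabel = fig.supylabel('Daily Rentals')

fig.tight_layout()
```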
Custom Spacing
Adjusting Layout
import pandas as pd
import matplotlib.pyplot as plt
bikes_hour = pd.read_csv('hour.csv')
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
# Create plots
hourly_pattern = bikes_hour.groupby('hr')['cnt'].mean()
axes[0, 0].plot(hourly_pattern.index, hourly_pattern.values, linewidth=2)
axes[0, 0].set_title('Hourly Pattern')
axes[0, 1].hist(bikes_hour['cnt'], bins=40, edgecolor='black', alpha=0.7)
axes[0, 1].set_title('Rental Distribution')
axes[1, 0].scatter(bikes_hour['temp'], bikes_hour['cnt'], alpha=0.3, s=5)
axes[1, 0].set_title('Temperature vs Rentals')
axes[1, 1].scatter(bikes_hour['hum'], bikes_hour['cnt'], alpha=0.3, s=5)
axes[1, 1].set_title('Humidity vs Rentals')
# Adjust spacing manually
plt.subplots_adjust(hspace=0.3, wspace=0.3)
plt.show()
Parameters:
- hspace: vertical spacing between rows (height)
- wspace: horizontal spacing between columns (width)
- Values are fractions of the average axes height (hspace) or width (wspace)
Alternative: plt.tight_layout() does this automatically!
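plt.subplots_adjust() acts on the current figure. The Figure method fig.subplots_adjust() does the same for a specific figure, which fits the object-oriented style. The minimal sketch below applies the spacing and then reads it back through fig.subplotpars. The Agg backend is an assumption for headless running.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for the sketch
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Padding between panels, as fractions of the average axes height/width
fig.subplots_adjust(hspace=0.3, wspace=0.3)
print(fig.subplotpars.hspace, fig.subplotpars.wspace)   # 0.3 0.3

# Larger values leave more room, e.g. for long titles or tick labels
fig.subplots_adjust(hspace=0.5)
print(fig.subplotpars.hspace)                           # 0.5
```

tight_layout() computes its own spacing and applies it through these same subplot parameters. Calling it after subplots_adjust() therefore replaces the manual values.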
Practical Example: Comprehensive Analysis
Weather Analysis Dashboard
import pandas as pd
import matplotlib.pyplot as plt
bikes = pd.read_csv('day.csv')
bikes['dteday'] = pd.to_datetime(bikes['dteday'])
# Create a 3-row, 2-column grid and unpack the six axes
fig, axes = plt.subplots(3, 2, figsize=(16, 12))
ax1, ax2, ax3, ax4, ax5, ax6 = axes.flat
# Plot 1: Rentals over time
ax1.plot(bikes['dteday'], bikes['cnt'], color='steelblue', linewidth=1)
ax1.set_ylabel('Daily Rentals')
ax1.set_title('Daily Rentals Over Time')
ax1.grid(True, alpha=0.3)
# Plot 2: Rental distribution
ax2.hist(bikes['cnt'], bins=30, color='coral', edgecolor='black', alpha=0.7)
ax2.set_xlabel('Daily Rentals')
ax2.set_ylabel('Frequency')
ax2.set_title('Rental Distribution')
ax2.grid(True, alpha=0.3, axis='y')
# Plot 3: Temperature vs Rentals
ax3.scatter(bikes['temp'], bikes['cnt'], alpha=0.5, s=30, color='red')
corr_temp = bikes['temp'].corr(bikes['cnt'])
ax3.set_xlabel('Normalized Temperature')
ax3.set_ylabel('Daily Rentals')
ax3.set_title(f'Temperature vs Rentals (r = {corr_temp:.3f})')
ax3.grid(True, alpha=0.3)
# Plot 4: Humidity vs Rentals
ax4.scatter(bikes['hum'], bikes['cnt'], alpha=0.5, s=30, color='blue')
corr_hum = bikes['hum'].corr(bikes['cnt'])
ax4.set_xlabel('Humidity')
ax4.set_ylabel('Daily Rentals')
ax4.set_title(f'Humidity vs Rentals (r = {corr_hum:.3f})')
ax4.grid(True, alpha=0.3)
# Plot 5: Seasonal averages
season_avg = bikes.groupby('season')['cnt'].mean()
season_names = ['Spring', 'Summer', 'Fall', 'Winter']
ax5.bar(season_names, season_avg.values, color='orange', edgecolor='black')
ax5.set_ylabel('Average Daily Rentals')
ax5.set_title('Average Rentals by Season')
ax5.grid(True, alpha=0.3, axis='y')
# Plot 6: Weather condition impact
weather_avg = bikes.groupby('weathersit')['cnt'].mean()
weather_names = ['Clear', 'Mist', 'Light Rain']
ax6.barh(weather_names, weather_avg.values, color='green', edgecolor='black')
ax6.set_xlabel('Average Daily Rentals')
ax6.set_title('Weather Condition Impact')
ax6.grid(True, alpha=0.3, axis='x')
fig.suptitle('Bike Rental Analysis Dashboard', fontsize=18, fontweight='bold')
plt.tight_layout()
plt.show()
Comparison: Subplots vs GridSpec
Using plt.subplots()
import matplotlib.pyplot as plt
# Simple regular grid
fig, axes = plt.subplots(2, 2, figsize=(10, 8))
Best for: regular grids where all subplots have the same size.
Using GridSpec (Preview)
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
# Complex irregular layout
fig = plt.figure(figsize=(12, 8))
gs = GridSpec(3, 3, figure=fig)
ax1 = fig.add_subplot(gs[0, :]) # Top row, all columns
ax2 = fig.add_subplot(gs[1, :-1]) # Middle row, first 2 columns
ax3 = fig.add_subplot(gs[1:, -1]) # Right side, last 2 rows
ax4 = fig.add_subplot(gs[2, 0]) # Bottom-left
ax5 = fig.add_subplot(gs[2, 1]) # Bottom-middle
Best for: irregular layouts and subplots of different sizes. (You will learn this in the next lesson!)
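To check which grid cells each GridSpec slice covers, you can ask each axes for its SubplotSpec and read its rowspan and colspan, both of which are range objects. The sketch below rebuilds the layout above and prints each span. The Agg backend and the dictionary of panel names are assumptions made for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for the sketch
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

fig = plt.figure(figsize=(12, 8))
gs = GridSpec(3, 3, figure=fig)

panels = {
    'ax1': fig.add_subplot(gs[0, :]),    # top row, all columns
    'ax2': fig.add_subplot(gs[1, :-1]),  # middle row, first 2 columns
    'ax3': fig.add_subplot(gs[1:, -1]),  # right column, last 2 rows
    'ax4': fig.add_subplot(gs[2, 0]),    # bottom-left
    'ax5': fig.add_subplot(gs[2, 1]),    # bottom-middle
}

# Print the grid rows and columns each axes occupies
for name, ax in panels.items():
    spec = ax.get_subplotspec()
    print(name, 'rows', list(spec.rowspan), 'cols', list(spec.colspan))

# ax1 rows [0] cols [0, 1, 2]
# ax2 rows [1] cols [0, 1]
# ax3 rows [1, 2] cols [2]
# ax4 rows [2] cols [0]
# ax5 rows [2] cols [1]
```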
Summary
You learned to create multi-panel figures with subplots:
- plt.subplots(nrows, ncols) creates a figure and its axes
- fig is the entire canvas; axes are the plotting areas
- Access subplots: axes[i] for 1D, axes[row, col] for 2D grids
- Plot on axes: ax.plot(), ax.scatter(), ax.hist(), etc.
- Share axes: sharex=True, sharey=True for aligned comparisons
- Iterate: use axes.flat to loop over 2D grids
- Spacing: plt.tight_layout() or plt.subplots_adjust()
- Figure title: fig.suptitle() for an overall title
- Use cases: comparing data, multi-variable analysis, dashboards
Next Steps: In the next lesson, you will learn advanced grid layouts with GridSpec for complex dashboard designs.
Practice: Create a 2×2 subplot grid showing: (1) hourly rental pattern, (2) temperature distribution, (3) humidity vs rentals scatter, (4) weather condition bar chart.