I have the following plot:
import matplotlib.pyplot as plt
# Question snippet: one figure holding two subplots stacked in a 2-row,
# 1-column grid. x1/y1 and x2/y2 are not defined in this snippet.
fig2 = plt.figure()
ax3 = fig2.add_subplot(2,1,1)
ax4 = fig2.add_subplot(2,1,2)
# Both panels use log-log scaling.
ax4.loglog(x1, y1)
ax3.loglog(x2, y2)
# Only the first subplot (ax3) gets a y label here.
ax3.set_ylabel('hello')
I want to be able to create axes labels and titles not just for each of the two subplots, but also common labels that span both subplots. For example, since both plots have identical axes, I only need one set of x- and y-axis labels. I do want different titles for each subplot though.
I tried a few things but none of them worked right
You can create a big subplot that covers the two subplots and then set the common labels.
import random
import matplotlib.pyplot as plt
x = range(1, 101)
y1 = [random.randint(1, 100) for _ in x]
y2 = [random.randint(1, 100) for _ in x]

fig = plt.figure()
# The enclosing axes is added before the two panels so that the panels
# are drawn on top of it.
ax = fig.add_subplot(111)
ax1, ax2 = fig.add_subplot(211), fig.add_subplot(212)

# Hide the frame and ticks of the enclosing axes; only its labels remain.
for side in ('top', 'bottom', 'left', 'right'):
    ax.spines[side].set_color('none')
ax.tick_params(labelcolor='w', top=False, bottom=False, left=False, right=False)

for panel, series in ((ax1, y1), (ax2, y2)):
    panel.loglog(x, series)

# Labels shared by both panels are attached to the enclosing axes.
ax.set_xlabel('common xlabel')
ax.set_ylabel('common ylabel')

# Each panel keeps its own title.
ax1.set_title('ax1 title')
ax2.set_title('ax2 title')

plt.savefig('common_labels.png', dpi=300)
https://i.imgur.com/k0KSQ.png
Another way is using fig.text() to set the locations of the common labels directly.
import random
import matplotlib.pyplot as plt
x = range(1, 101)
y1 = [random.randint(1, 100) for _ in x]
y2 = [random.randint(1, 100) for _ in x]

fig = plt.figure()
ax1 = fig.add_subplot(211)
ax2 = fig.add_subplot(212)

# Plot each series on its own panel and give the panel its own title.
for panel, series, title in ((ax1, y1, 'ax1 title'), (ax2, y2, 'ax2 title')):
    panel.loglog(x, series)
    panel.set_title(title)

# Common labels are placed directly in figure coordinates.
centered = {'ha': 'center', 'va': 'center'}
fig.text(0.5, 0.04, 'common xlabel', **centered)
fig.text(0.06, 0.5, 'common ylabel', rotation='vertical', **centered)

plt.savefig('common_labels_text.png', dpi=300)
https://i.imgur.com/GRyzs.png
One simple way using subplots:
import matplotlib.pyplot as plt
fig, axes = plt.subplots(3, 4, sharex=True, sharey=True)

# One frameless axes spanning the whole grid; it only carries the
# common labels.
big_ax = fig.add_subplot(111, frameon=False)

# Hide its ticks, tick labels and grid so nothing is drawn over the panels.
big_ax.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
big_ax.grid(False)

big_ax.set_xlabel("common X")
big_ax.set_ylabel("common Y")
Change 'off' to False with newer versions of Matplotlib (I have 2.2.2).
`for ax in axes: ax.plot(x, y)` doesn't seem to do any good.
plt.setp() will do the job:
# Draw a 3x3 grid of scatter plots, each with its own title.
fig, axs = plt.subplots(3, 3, figsize=(15, 8), sharex=True, sharey=True)
for i, ax in enumerate(axs.flat):
    xs, ys = np.random.normal(size=(2, 200))
    ax.scatter(xs, ys)
    ax.set_title(f'Title {i}')

# Label only the outer edge: x labels on the bottom row, y labels on the
# first column.
plt.setp(axs[-1, :], xlabel='x axis label')
plt.setp(axs[:, 0], ylabel='y axis label')
https://i.stack.imgur.com/PuExX.png
plt.setp(axs[-1, :], xlabel='x axis label') only modifies the label text (apparently). If you want to customize its text parameters, do, for instance, ax.set_xlabel(None, size=12, weight='demibold', color='xkcd:lime green', labelpad=0.33) inside the for loop. Check matplotlib.text.Text for more **kwargs.
New in matplotlib 3.4.0
There are now built-in methods to set common axis labels:
- supxlabel: fig.supxlabel('common x label')
- supylabel: fig.supylabel('common y label')
To reproduce OP's loglog plots (common labels but separate titles):
x = np.arange(0.01, 10.01, 0.01)
y = 2 ** x

fig, (ax1, ax2) = plt.subplots(2, 1, constrained_layout=True)
ax1.loglog(y, x)
ax2.loglog(x, y)

# Each subplot keeps its own title.
for panel, title in ((ax1, 'ax1.title'), (ax2, 'ax2.title')):
    panel.set_title(title)

# The common axis labels belong to the figure, not to either subplot.
fig.supxlabel('fig.supxlabel')
fig.supylabel('fig.supylabel')
https://i.stack.imgur.com/4EeGs.png
Wen-wei Liao's answer is good if you are not trying to export vector graphics, or if you have set up your matplotlib backend to ignore colorless axes; otherwise the hidden axes will show up in the exported graphic.
My answer, suplabel, here is similar to fig.suptitle, which uses the fig.text function. Therefore there is no axes artist being created and made colorless. However, if you try to call it multiple times you will get text added on top of each other (as fig.suptitle does too). Wen-wei Liao's answer doesn't, because fig.add_subplot(111) will return the same Axes object if it is already created.
My function can also be called after the plots have been created.
def suplabel(axis,label,label_prop=None,
             labelpad=5,
             ha='center',va='center'):
    ''' Add super ylabel or xlabel to the current figure

    Similar to matplotlib.suptitle: the label is a figure-level text placed
    in figure coordinates, so no extra axes is created. A "y" label sits to
    the left of the left-most axes, vertically centered at 0.5; an "x" label
    sits below the bottom-most axes, horizontally centered at 0.5. Every
    call adds a new text; earlier labels are not replaced.

    axis       - string: "x" or "y" (case-insensitive)
    label      - string
    label_prop - keyword dictionary for Text
    labelpad   - padding from the axis (default: 5); the offset applied in
                 figure coordinates is labelpad / fig.dpi
    ha         - horizontal alignment (default: "center")
    va         - vertical alignment (default: "center")

    Raises ValueError if axis is neither "x" nor "y", or if the current
    figure has no axes to position the label against.
    '''
    fig = pylab.gcf()

    which = axis.lower()
    if which not in ('x', 'y'):
        raise ValueError("Unexpected axis: x or y")
    if not fig.axes:
        raise ValueError("suplabel needs a figure with at least one axes")

    # Left-most and bottom-most edges over all axes, in figure coordinates.
    positions = [ax.get_position() for ax in fig.axes]
    xmin = min(pos.xmin for pos in positions)
    ymin = min(pos.ymin for pos in positions)

    # NOTE(review): labelpad / dpi is used directly as a figure fraction;
    # kept as-is so existing callers get the same placement - confirm units.
    pad = float(labelpad)/fig.dpi
    if which == 'y':
        rotation = 90.
        x, y = xmin - pad, 0.5
    else:
        rotation = 0.
        x, y = 0.5, ymin - pad

    if label_prop is None:
        label_prop = dict()
    # fig.text attaches the label to the figure itself (figure coordinates
    # are its default transform) instead of to whichever axes is current.
    fig.text(x,y,label,rotation=rotation,
             ha=ha,va=va,
             **label_prop)
Here is a solution where you set the ylabel of one of the plots and adjust the position of it so it is centered vertically. This way you avoid problems mentioned by KYC.
import numpy as np
import matplotlib.pyplot as plt
def set_shared_ylabel(a, ylabel, labelpad = 0.01):
    """Set a y label shared by multiple axes

    The label is attached to the last axes in ``a`` and positioned in figure
    coordinates: vertically centered between the top of ``a[0]`` and the
    bottom of ``a[-1]``, and ``labelpad`` to the left of the left-most y tick
    label found on any of the axes. Existing y labels on all axes in ``a``
    are cleared first.

    Parameters
    ----------
    a: list of axes
        a[0] is treated as the top plot and a[-1] as the bottom plot
    ylabel: string
    labelpad: float
        Sets the padding between ticklabels and axis label
        (in figure coordinates)
    """
    f = a[0].get_figure()
    f.canvas.draw() #sets f.canvas.renderer needed below

    # get the center position for all plots
    top = a[0].get_position().y1
    bottom = a[-1].get_position().y0

    # Bbox.inverse_transformed was deprecated in Matplotlib 3.3 and later
    # removed; transforming with the inverted transform is the equivalent.
    to_figure = f.transFigure.inverted()

    # get the coordinates of the left side of the tick labels
    # NOTE(review): get_ticklabel_extents may itself be deprecated in recent
    # Matplotlib releases - confirm against the version in use.
    x0 = 1
    for at in a:
        at.set_ylabel('') # just to make sure we don't end up with multiple labels
        bboxes, _ = at.yaxis.get_ticklabel_extents(f.canvas.renderer)
        bboxes = bboxes.transformed(to_figure)
        x0 = min(x0, bboxes.x0)
    tick_label_left = x0

    # set position of label
    a[-1].set_ylabel(ylabel)
    a[-1].yaxis.set_label_coords(tick_label_left - labelpad,(bottom + top)/2, transform=f.transFigure)
length = 100
x = np.linspace(0, 100, length)
# Two series on very different scales, so their tick labels differ in width.
y1 = np.random.random(length) * 1000
y2 = np.random.random(length)

# Two stacked panels sharing the x axis, with no vertical gap between them.
f, a = plt.subplots(2, sharex=True, gridspec_kw=dict(hspace=0))
for panel, series in zip(a, (y1, y2)):
    panel.plot(x, series)

set_shared_ylabel(a, 'shared y label (a. u.)')
https://i.stack.imgur.com/d5iTr.png
# The lists loss and acc (and iteration1 / iteration2) are your data.
fig = plt.figure()
ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122)

# (axes, x data, y data, title, y label) for each side-by-side panel.
panels = (
    (ax1, iteration1, loss, 'Training Loss', 'Loss'),
    (ax2, iteration2, acc, 'Training Accuracy', 'Accuracy'),
)
for panel, steps, values, heading, y_name in panels:
    panel.plot(steps, values)
    panel.set_title(heading)
    panel.set_xlabel('Iteration')
    panel.set_ylabel(y_name)
The methods in the other answers will not work properly when the yticks are large. The ylabel will either overlap with ticks, be clipped on the left or completely invisible/outside of the figure.
I've modified Hagne's answer so it works with more than 1 column of subplots, for both xlabel and ylabel, and it shifts the plot to keep the ylabel visible in the figure.
def set_shared_ylabel(a, xlabel, ylabel, labelpad = 0.01, figleftpad=0.05):
    """Set an x label and a y label shared by a 2-D grid of axes

    Existing x labels on the bottom row and y labels on the first column are
    cleared. The subplots' left margin is then widened so the y label fits,
    and both shared labels are attached to the bottom-right axes and
    positioned in figure coordinates.

    Parameters
    ----------
    a: 2-D array of axes
        indexed as a[row, col], e.g. from plt.subplots(..., squeeze=False)
    xlabel: string
    ylabel: string
    labelpad: float
        Sets the padding between ticklabels and axis label
    figleftpad: float
        Additional left padding, in figure coordinates, to fit the ylabel
    """
    f = a[0,0].get_figure()
    f.canvas.draw() #sets f.canvas.renderer needed below

    # Bbox.inverse_transformed was deprecated in Matplotlib 3.3 and later
    # removed; transforming with the inverted transform is the equivalent.
    to_figure = f.transFigure.inverted()

    # get the center position for all plots
    top = a[0,0].get_position().y1
    bottom = a[-1,-1].get_position().y0

    # get the coordinates of the left side of the tick labels
    x0 = 1
    x1 = 1
    for at_row in a:
        at = at_row[0]
        at.set_ylabel('') # just to make sure we don't end up with multiple labels
        bboxes, _ = at.yaxis.get_ticklabel_extents(f.canvas.renderer)
        bboxes = bboxes.transformed(to_figure)
        xt = bboxes.x0
        if xt < x0:
            # remember both edges of the left-most tick label
            x0 = xt
            x1 = bboxes.x1
    tick_label_left = x0

    # shrink plot on left to prevent ylabel clipping
    # (x1 - tick_label_left) is the width of that tick label, i.e. how much
    # padding is needed to fit tick labels in the figure;
    # figleftpad is additional padding to fit the ylabel.
    # Adjust the figure that owns `a`, not whichever figure is current.
    f.subplots_adjust(left=(x1 - tick_label_left) + figleftpad)

    # set position of label,
    # note that (figleftpad-labelpad) refers to the middle of the ylabel
    a[-1,-1].set_ylabel(ylabel)
    a[-1,-1].yaxis.set_label_coords(figleftpad-labelpad,(bottom + top)/2, transform=f.transFigure)

    # set xlabel
    # Horizontal extent of the grid, read after subplots_adjust has moved
    # the left edge, so the x label is centered under the final layout.
    left = a[0,0].get_position().x0
    right = a[-1,-1].get_position().x1

    y0 = 1
    for at in a[-1]:
        at.set_xlabel('') # just to make sure we don't end up with multiple labels
        bboxes, _ = at.xaxis.get_ticklabel_extents(f.canvas.renderer)
        bboxes = bboxes.transformed(to_figure)
        y0 = min(y0, bboxes.y0)
    tick_label_bottom = y0

    a[-1, -1].set_xlabel(xlabel)
    a[-1, -1].xaxis.set_label_coords((left + right) / 2, tick_label_bottom - labelpad, transform=f.transFigure)
It works for the following example, while Hagne's answer won't draw ylabel (since it's outside of the canvas) and KYC's ylabel overlaps with the tick labels:
import matplotlib.pyplot as plt
import itertools
fig, axes = plt.subplots(3, 4, sharey='row', sharex=True, squeeze=False)
fig.subplots_adjust(hspace=.5)

# Each panel spans 0..4**i, so later panels get much wider tick labels.
for i, panel in enumerate(itertools.chain.from_iterable(axes)):
    upper = 4**i
    panel.plot([0, upper], [0, upper])
    panel.set_title(i)

set_shared_ylabel(axes, 'common X', 'common Y')
plt.show()
Alternatively, if you are fine with colorless axis, I've modified Julian Chen's solution so ylabel won't overlap with tick labels.
Basically, we just have to set the ylims of the colorless axes so that they match the largest ylims of the subplots; the colorless tick labels then set the correct location for the ylabel.
Again, we have to shrink the plot to prevent clipping. Here I've hard coded the amount to shrink, but you can play around to find a number that works for you or calculate it like in the method above.
import matplotlib.pyplot as plt
import itertools
fig, axes = plt.subplots(3, 4, sharey='row', sharex=True, squeeze=False)
fig.subplots_adjust(hspace=.5)

# Track the widest y range seen on any panel while plotting.
miny = maxy = 0
for i, panel in enumerate(itertools.chain.from_iterable(axes)):
    upper = 4**i
    panel.plot([0, upper], [0, upper])
    panel.set_title(i)
    low, high = panel.get_ylim()
    miny = min(miny, low)
    maxy = max(maxy, high)

# Add a big frameless axes whose y limits match the largest range of any
# subplot, so its (invisible) tick labels reserve the right amount of room.
ax_invis = fig.add_subplot(111, frameon=False)
ax_invis.set_ylim([miny, maxy])

# Hide ticks and tick labels of the big axes; keep only its axis labels.
ax_invis.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
ax_invis.set_xlabel("common X")
ax_invis.set_ylabel("common Y")

# Shrink the plot area to prevent the y label from being clipped.
fig.subplots_adjust(left=0.15)
plt.show()
Success story sharing
ax has to be created before ax1 and ax2, otherwise the big plot will cover up the small plots.