我对这段代码的工作方式有点困惑:
fig, axes = plt.subplots(nrows=2, ncols=2)
plt.show()
在这种情况下,无花果轴如何工作?它有什么作用?
还有为什么这不能做同样的事情:
fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)
有几种方法可以做到这一点。 subplots
方法创建图形以及随后存储在 ax
数组中的子图。例如:
import matplotlib.pyplot as plt
x = range(10)
y = range(10)
fig, ax = plt.subplots(nrows=2, ncols=2)
for row in ax:
for col in row:
col.plot(x, y)
plt.show()
https://i.stack.imgur.com/2JxAs.png
然而,这样的事情也可以工作,但它不是那么“干净”,因为你正在创建一个带有子图的图形,然后在它们之上添加:
fig = plt.figure()
plt.subplot(2, 2, 1)
plt.plot(x, y)
plt.subplot(2, 2, 2)
plt.plot(x, y)
plt.subplot(2, 2, 3)
plt.plot(x, y)
plt.subplot(2, 2, 4)
plt.plot(x, y)
plt.show()
https://i.stack.imgur.com/0EiUD.png
import matplotlib.pyplot as plt
fig, ax = plt.subplots(2, 2)
ax[0, 0].plot(range(10), 'r') #row=0, col=0
ax[1, 0].plot(range(10), 'b') #row=1, col=0
ax[0, 1].plot(range(10), 'g') #row=0, col=1
ax[1, 1].plot(range(10), 'k') #row=1, col=1
plt.show()
https://i.stack.imgur.com/jMSQV.jpg
您还可以在 subplots 调用中解压缩轴
并设置是否要在子图之间共享 x 和 y 轴
像这样:
import matplotlib.pyplot as plt
# fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
ax1, ax2, ax3, ax4 = axes.flatten()
ax1.plot(range(10), 'r')
ax2.plot(range(10), 'b')
ax3.plot(range(10), 'g')
ax4.plot(range(10), 'k')
plt.show()
https://i.stack.imgur.com/kvOOL.png
您可能对从 matplotlib 版本 2.1 开始该问题的第二个代码也可以正常工作这一事实感兴趣。
从 change log:
Figure 类现在有 subplots 方法 Figure 类现在有一个 subplots() 方法,其行为与 pyplot.subplots() 相同,但在现有图形上。
例子:
import matplotlib.pyplot as plt
fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)
plt.show()
阅读文档:matplotlib.pyplot.subplots
pyplot.subplots()
返回一个元组 fig, ax
,它使用符号解压缩在两个变量中
fig, axes = plt.subplots(nrows=2, ncols=2)
编码:
fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)
不起作用,因为 subplots()
是 pyplot
中的函数,而不是对象 Figure
的成员。
依次遍历所有子图:
fig, axes = plt.subplots(nrows, ncols)
for ax in axes.flatten():
ax.plot(x,y)
访问特定索引:
for row in range(nrows):
for col in range(ncols):
axes[row,col].plot(x[row], y[col])
带有熊猫的子图
这个答案适用于带有 pandas 的子图,它使用 matplotlib 作为默认的绘图后端。
以下是从 pandas.DataFrame 实现 1. 和 2. 开始创建子图的四个选项,用于宽格式的数据,为每列创建子图。实现 3. 和 4. 用于长格式数据,为列中的每个唯一值创建子图。
实现 1. 和 2. 用于宽格式的数据,为每列创建子图。
实现 3. 和 4. 用于长格式数据,为列中的每个唯一值创建子图。
在 python 3.8.11、pandas 1.3.2、matplotlib 3.4.3、seaborn 0.11.2 中测试
导入和数据
import seaborn as sns # data only
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# wide dataframe
df = sns.load_dataset('planets').iloc[:, 2:5]
orbital_period mass distance
0 269.300 7.10 77.40
1 874.774 2.21 56.95
2 763.000 2.60 19.84
3 326.030 19.40 110.62
4 516.220 10.50 119.47
# long dataframe
dfm = sns.load_dataset('planets').iloc[:, 2:5].melt()
variable value
0 orbital_period 269.300
1 orbital_period 874.774
2 orbital_period 763.000
3 orbital_period 326.030
4 orbital_period 516.220
1. subplots=True 和布局,对于每一列
在 pandas.DataFrame.plot 中使用参数 subplots=True 和 layout=(rows, cols)
此示例使用 kind='density',但 kind 有不同的选项,这适用于所有选项。如果不指定种类,则默认使用折线图。
ax 是 pandas.DataFrame.plot 返回的 AxesSubplot 数组
如果需要,请参阅如何获取 Figure 对象。如何保存熊猫子图
如何保存熊猫子图
axes = df.plot(kind='density', subplots=True, layout=(2, 2), sharex=False, figsize=(10, 6))
# extract the figure object; only used for tight_layout in this example
fig = axes[0][0].get_figure()
# set the individual titles
for ax, title in zip(axes.ravel(), df.columns):
ax.set_title(title)
fig.tight_layout()
plt.show()
2. plt.subplots,对于每一列
使用 matplotlib.pyplot.subplots 创建一个 Axes 数组,然后将 axes[i, j] 或 axes[n] 传递给 ax 参数。此选项使用 pandas.DataFrame.plot,但可以使用其他轴级别绘图调用作为替代(例如 sns.kdeplot、plt.plot 等)。使用 .ravel 或 . 将 Axes 的子图数组折叠成一维是最简单的。变平。请参阅 .ravel 与 .flatten。应用于每个轴的任何需要迭代的变量都与 .zip 组合在一起(例如,列、轴、颜色、调色板等)。每个对象的长度必须相同。
此选项使用 pandas.DataFrame.plot,但可以使用其他轴级别绘图调用作为替代(例如 sns.kdeplot、plt.plot 等)
使用 .ravel 或 .flatten 将 Axes 的子图数组折叠成一维是最简单的。请参阅 .ravel 与 .flatten。
应用于每个轴的任何需要迭代的变量都与 .zip 组合在一起(例如,列、轴、颜色、调色板等)。每个对象的长度必须相同。
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 6)) # define the figure and subplots
axes = axes.ravel() # array to 1D
cols = df.columns # create a list of dataframe columns to use
colors = ['tab:blue', 'tab:orange', 'tab:green'] # list of colors for each subplot, otherwise all subplots will be one color
for col, color, ax in zip(cols, colors, axes):
df[col].plot(kind='density', ax=ax, color=color, label=col, title=col)
ax.legend()
fig.delaxes(axes[3]) # delete the empty subplot
fig.tight_layout()
plt.show()
1. 和 2 的结果。
https://i.stack.imgur.com/D64Me.png
3. plt.subplots,对于 .groupby 中的每个组
这与 2. 类似,不同之处在于它将颜色和轴压缩到 .groupby 对象。
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 6)) # define the figure and subplots
axes = axes.ravel() # array to 1D
dfg = dfm.groupby('variable') # get data for each unique value in the first column
colors = ['tab:blue', 'tab:orange', 'tab:green'] # list of colors for each subplot, otherwise all subplots will be one color
for (group, data), color, ax in zip(dfg, colors, axes):
data.plot(kind='density', ax=ax, color=color, title=group, legend=False)
fig.delaxes(axes[3]) # delete the empty subplot
fig.tight_layout()
plt.show()
https://i.stack.imgur.com/Nw1Yo.png
4. seaborn 人物级剧情
使用 seaborn 图形级别图,并使用 col 或 row 参数。 seaborn 是 matplotlib 的高级 API。请参阅 seaborn:API 参考
p = sns.displot(data=dfm, kind='kde', col='variable', col_wrap=2, x='value', hue='variable',
facet_kws={'sharey': False, 'sharex': False}, height=3.5, aspect=1.75)
sns.move_legend(p, "upper left", bbox_to_anchor=(.55, .45))
其他答案很好,这个答案是一个可能有用的组合。
import numpy as np
import matplotlib.pyplot as plt
# Optional: define x for all the sub-plots
x = np.linspace(0,2*np.pi,100)
# (1) Prepare the figure infrastructure
fig, ax_array = plt.subplots(nrows=2, ncols=2)
# flatten the array of axes, which makes them easier to iterate through and assign
ax_array = ax_array.flatten()
# (2) Plot loop
for i, ax in enumerate(ax_array):
ax.plot(x , np.sin(x + np.pi/2*i))
#ax.set_title(f'plot {i}')
# Optional: main title
plt.suptitle('Plots')
https://i.stack.imgur.com/663hv.png
概括
准备图形基础结构 获取 ax_array,一个子图数组 展平数组以便在一个“for 循环”中使用它 绘图循环 在展平的 ax_array 上循环以更新子图 可选:使用枚举来跟踪子图编号 一旦展平,每个 ax_array可以从 0 到 nrows x ncols -1 单独索引(例如 ax_array[0]、ax_array[1]、ax_array[2]、ax_array[3])。
这是一个简单的解决方案
fig, ax = plt.subplots(nrows=2, ncols=3, sharex=True, sharey=False)
for sp in fig.axes:
sp.plot(range(10))
https://i.stack.imgur.com/5dBZe.png
将坐标区数组转换为 1D
使用 plt.subplots(nrows, ncols) 生成子图,其中 nrows 和 ncols 均大于 1,返回
在 nrows=1 或 ncols=1 的情况下,不需要展平坐标轴,因为坐标轴已经是一维的,这是默认参数 squeeze=True 的结果
访问对象的最简单方法是使用 .ravel()、.flatten() 或 .flat 将数组转换为一维。 .ravel 与 .flatten flatten 总是返回一个副本。 ravel 尽可能返回原始数组的视图。
.ravel 与 .flatten flatten 总是返回一个副本。 ravel 尽可能返回原始数组的视图。
flatten 总是返回一个副本。
ravel 尽可能返回原始数组的视图。
将轴数组转换为 1-d 后,有多种绘图方法。
这个答案与 seaborn 轴级图相关,它具有 ax= 参数(例如 sns.barplot(..., ax=ax[0])。seaborn 是 matplotlib 的高级 API。参见 Figure-level vs . 轴级函数和 seaborn 未在定义的子图中绘制
seaborn 是 matplotlib 的高级 API。请参阅图形级与轴级函数,并且 seaborn 未在定义的子图中绘制
import matplotlib.pyplot as plt
import numpy as np # sample data only
# example of data
rads = np.arange(0, 2*np.pi, 0.01)
y_data = np.array([np.sin(t*rads) for t in range(1, 5)])
x_data = [rads, rads, rads, rads]
# Generate figure and its subplots
fig, axes = plt.subplots(nrows=2, ncols=2)
# axes before
array([[<AxesSubplot:>, <AxesSubplot:>],
[<AxesSubplot:>, <AxesSubplot:>]], dtype=object)
# convert the array to 1 dimension
axes = axes.ravel()
# axes after
array([<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>],
dtype=object)
遍历展平的数组如果子图比数据多,这将导致 IndexError: list index out of range 请尝试选项 3。改为,或选择轴的子集(例如 axes[:-2])
for i, ax in enumerate(axes):
ax.plot(x_data[i], y_data[i])
按索引访问每个轴
axes[0].plot(x_data[0], y_data[0])
axes[1].plot(x_data[1], y_data[1])
axes[2].plot(x_data[2], y_data[2])
axes[3].plot(x_data[3], y_data[3])
索引数据和轴
for i in range(len(x_data)):
axes[i].plot(x_data[i], y_data[i])
将轴和数据压缩在一起,然后遍历元组列表
for ax, x, y in zip(axes, x_data, y_data):
ax.plot(x, y)
输出
https://i.stack.imgur.com/AvhBx.png
一个选项是将每个轴分配给一个变量 fig, (ax1, ax2, ax3) = plt.subplots(1, 3)。然而,正如所写,这只适用于 nrows=1 或 ncols=1 的情况。这是基于 plt.subplots 返回的数组的形状,很快就会变得很麻烦。 fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2) 对于 2 x 2 数组。此选项对于两个子图最有用(例如:fig, (ax1, ax2) = plt.subplots(1, 2) 或 fig, (ax1, ax2) = plt.subplots(2, 1))。对于更多子图,展平和迭代轴数组会更有效。
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2) 对于 2 x 2 数组。
此选项对于两个子图最有用(例如:fig, (ax1, ax2) = plt.subplots(1, 2) 或 fig, (ax1, ax2) = plt.subplots(2, 1))。对于更多子图,展平和迭代轴数组会更有效。
如果您真的想使用循环,请执行以下操作:
def plot(data):
fig = plt.figure(figsize=(100, 100))
for idx, k in enumerate(data.keys(), 1):
x, y = data[k].keys(), data[k].values
plt.subplot(63, 10, idx)
plt.bar(x, y)
plt.show()
另一个简洁的解决方案是:
// set up structure of plots
f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20,10))
// for plot 1
ax1.set_title('Title A')
ax1.plot(x, y)
// for plot 2
ax2.set_title('Title B')
ax2.plot(x, y)
// for plot 3
ax3.set_title('Title C')
ax3.plot(x,y)
nrows=1
或 ncols=1
。这对于 fig, (ax1, ax2) = plt.subplots(2, 1)
或 fig, (ax1, ax2) = plt.subplots(1, 2)
以外的任何内容都会很快变得低效,如本 answer 中所述
ax
是什么,但不知道fig
是什么。这些是什么?matplotlib.figure.Figure
类,您可以通过它对绘制的图形进行大量操作。例如,您可以将颜色条添加到特定的子图,您可以更改所有子图后面的背景颜色。您可以修改这些子图的布局或为其添加新的小斧头。最好你可能想要一个可以通过fig.suptitle(title)
方法获得的所有子图的主标题。最后,一旦您对情节感到满意,就可以使用fig.savefig
方法保存它。 @Leevo