# Import plotting libraries
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
# Configure matplotlib fonts: SimSun for the (Chinese) labels, 14 pt text,
# and STIX glyphs for any mathtext.
# NOTE(review): assumes the SimSun font is installed and visible to matplotlib.
config = {
    "font.family": "SimSun",
    "font.size": 14,
    "mathtext.fontset": "stix",
}
mpl.rcParams.update(config)
# Data to plot: training epochs (x axis) and, for each epoch, the loss on the
# training set (y_data_1) and on the test set (y_data_2).
# All three lists must have the same length (23 points).
x_data = list(range(2, 25))  # epochs 2..24 inclusive
y_data_1 = [8.179, 7.345, 7.029, 6.751, 6.611, 6.507, 6.319, 6.196,
            6.086, 6.006, 5.923, 5.81, 5.76, 5.705, 5.644, 5.583,
            5.523, 5.436, 5.485, 5.378, 5.343, 5.309, 5.257]
y_data_2 = [7.191, 6.637, 6.292, 6.112, 6.094, 5.74, 5.709, 5.667,
            5.583, 5.506, 5.41, 5.307, 5.239, 5.303, 5.282, 5.11,
            5.075, 5.083, 4.971, 4.976, 4.993, 4.963, 5.138]
# Create the figure and the main axes.
fig, ax = plt.subplots(1, 1)

# Zoom window (data coordinates) that is magnified in the inset. Defined once
# so the outline on the main plot, the inset limits and the connector lines
# always agree with each other.
ZOOM_X0, ZOOM_X1 = 20, 24
ZOOM_Y0, ZOOM_Y1 = 4.9, 5.5

# Line/marker style shared by both curves on the main axes and on the inset.
_CURVE_STYLE = {
    "linestyle": ":",
    "linewidth": 1.5,
    "marker": "o",
    "markersize": 5,
    "markeredgecolor": "black",
}


def _plot_curves(axis):
    """Draw the training (red) and test (green) loss curves onto *axis*."""
    axis.plot(x_data, y_data_1, color="r", markerfacecolor="C3",
              label="训练集", **_CURVE_STYLE)
    axis.plot(x_data, y_data_2, color="g", markerfacecolor="C2",
              label="测试集", **_CURVE_STYLE)


def _min_point(xs, ys):
    """Return ``(x, y)`` for the smallest value in *ys* (first occurrence)."""
    y_min = min(ys)
    return xs[ys.index(y_min)], y_min


# Plot both curves on the main axes; the legend picks up the line labels.
_plot_curves(ax)
ax.legend(ncol=2, loc="lower left")
# The font family (SimSun) already comes from rcParams; only the size is set.
ax.set_xlabel("训练轮次", fontsize=14)
ax.set_ylabel("损失", fontsize=14)

# Outline the zoom window on the main axes as a closed polygon.
ax.plot([ZOOM_X0, ZOOM_X1, ZOOM_X1, ZOOM_X0, ZOOM_X0],
        [ZOOM_Y0, ZOOM_Y0, ZOOM_Y1, ZOOM_Y1, ZOOM_Y0],
        lw=2, c="black")

# Add the inset axes and redraw the curves limited to the zoom window.
axins = inset_axes(ax, width="40%", height="30%", loc="lower left",
                   bbox_to_anchor=(0.4, 0.6, 1, 1),
                   bbox_transform=ax.transAxes)
_plot_curves(axins)
axins.set_xlim(ZOOM_X0, ZOOM_X1)
axins.set_ylim(ZOOM_Y0, ZOOM_Y1)

# Label each curve's minimum loss, computed from the data, not hard-coded.
# NOTE: with data coordinates matplotlib skips an annotation whose xy lies
# outside the axes, so each minimum must fall inside the zoom window.
for name, ys, text_y in (("训练集", y_data_1, 5.2), ("测试集", y_data_2, 5.05)):
    x_min, y_min = _min_point(x_data, ys)
    axins.annotate(
        f"{name}上的最小损失{y_min:.3f}",
        xy=(x_min, y_min),
        xytext=(ZOOM_X0, text_y),
        ha="center",
        va="bottom",
        arrowprops={"facecolor": "red", "shrink": 0.05},
    )

# Connect the inset's bottom corners to the top corners of the zoom outline.
# The patch spans two axes, so it is added to the figure. Adding it to one
# axes would give it that axes' clip path.
for x in (ZOOM_X0, ZOOM_X1):
    fig.add_artist(ConnectionPatch(xyA=(x, ZOOM_Y0), xyB=(x, ZOOM_Y1),
                                   coordsA="data", coordsB="data",
                                   axesA=axins, axesB=ax))

# Save the figure to disk, then display it.
fig.savefig("loss.PNG", dpi=300, bbox_inches="tight")
plt.show()
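# Copyright notice: this is an original article by weixin_47986386, licensed under
# the CC 4.0 BY-SA license; when reproducing it, include a link to the original
# source and this notice.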