matplotlib绘制局部放大图

  • Post author:weixin_47986386
  • Post category:其他


导入工具包

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

配置matplotlib字体

# Global text settings for the figure:
#   - SimSun as the font family so the Chinese labels can render,
#   - 14 pt as the base font size,
#   - the STIX font set for mathtext ($...$) rendering.
config = {
    "font.family": "SimSun",
    "font.size": 14,
    "mathtext.fontset": "stix",
}
mpl.rcParams.update(config)

定义数据

# Epoch numbers 2 through 24 inclusive (23 points) for the x axis.
x_data = list(range(2, 25))
# Loss per epoch for the two curves, one value per entry of x_data.
# y_data_1 is the training-set curve and y_data_2 the test-set curve
# (matching the legend labels 训练集 / 测试集 used when plotting).
y_data_1 = [
    8.179, 7.345, 7.029, 6.751, 6.611, 6.507, 6.319, 6.196,
    6.086, 6.006, 5.923, 5.81, 5.76, 5.705, 5.644, 5.583,
    5.523, 5.436, 5.485, 5.378, 5.343, 5.309, 5.257,
]
y_data_2 = [
    7.191, 6.637, 6.292, 6.112, 6.094, 5.74, 5.709, 5.667,
    5.583, 5.506, 5.41, 5.307, 5.239, 5.303, 5.282, 5.11,
    5.075, 5.083, 4.971, 4.976, 4.993, 4.963, 5.138,
]

创建画布

# One figure containing a single main Axes; the full loss curves are drawn on ``ax``.
fig, ax = plt.subplots(nrows=1, ncols=1)

绘图

# Draw both loss curves on the main axes.
# Each line carries its own legend label via ``label=``; passing
# ``labels=[...]`` to ``legend`` only pairs labels with lines by plotting
# order, which matplotlib documents as discouraged and fragile.
ax.plot(x_data, y_data_1, linestyle=":",
        color="r", linewidth=1.5,
        marker="o",
        markersize=5,
        markeredgecolor="black",
        markerfacecolor="C3",
        label="训练集")

ax.plot(x_data, y_data_2, linestyle=":",
        color="g", linewidth=1.5,
        marker="o",
        markersize=5,
        markeredgecolor="black",
        markerfacecolor="C2",
        label="测试集")
ax.legend(ncol=2, loc="lower left")
# Label the main axes explicitly (not via pyplot's "current axes"), stating
# the font size and family once instead of repeating them in a fontdict.
ax.set_xlabel("训练轮次", fontsize=14, family="SimSun")
ax.set_ylabel("损失", fontsize=14, family="SimSun")

在大图上绘制要放大的区域

# Outline the region magnified by the inset: x in [20, 24], y in [4.9, 5.5].
# The rectangle is traced corner by corner, repeating the first corner to
# close the path.
# Draw on ``ax`` explicitly: adding new axes (such as the inset) makes them
# pyplot's "current" axes, so ``plt.plot`` would target the wrong axes if
# this code were ever moved after the inset is created.
sx = [20, 24, 24, 20, 20]
sy = [4.9, 4.9, 5.5, 5.5, 4.9]
ax.plot(sx, sy, lw=2, c="black")

插入子坐标系并绘图

# Create an inset whose lower-left corner sits at (0.4, 0.6) in the main axes'
# axes coordinates, sized 40% x 30% of that axes.
# Redraw both curves in the inset with the same styling as the main plot,
# then zoom to x in [20, 24], y in [4.9, 5.5].
axins = inset_axes(ax, width="40%", height="30%", loc='lower left',
                   bbox_to_anchor=(0.4, 0.6, 1, 1),
                   bbox_transform=ax.transAxes)
shared_line_style = dict(linestyle=":", linewidth=1.5, marker="o",
                         markersize=5, markeredgecolor="black")
for series, line_color, face_color in ((y_data_1, "r", "C3"),
                                       (y_data_2, "g", "C2")):
    axins.plot(x_data, series, color=line_color,
               markerfacecolor=face_color, **shared_line_style)
axins.set_xlim(20, 24)
axins.set_ylim(4.9, 5.5)

给极值处添加标签

# Point an arrow at the minimum loss of each curve inside the inset.
# The minimum value and its epoch are computed from the data rather than
# hard-coded, so the label text and arrow target always match the plot.
# The text anchors stay at fixed positions (x=20 at the inset's left edge).
# NOTE(review): if the data changes so a minimum falls outside the inset's
# x/y limits, matplotlib will not draw that annotation — confirm the limits
# still cover the minima.
for y_values, text_prefix, text_y in (
    (y_data_1, "训练集上的最小损失", 5.2),
    (y_data_2, "测试集上的最小损失", 5.05),
):
    min_loss = min(y_values)
    # Epoch of the first occurrence of the minimum.
    min_epoch = x_data[y_values.index(min_loss)]
    axins.annotate(
        f"{text_prefix}{min_loss}",
        ha="center",
        va="bottom",
        xytext=(20, text_y),
        xy=(min_epoch, min_loss),
        arrowprops={'facecolor': 'red', 'shrink': 0.05}
    )

绘制放大图与原图像之间的连线

# Connect each bottom corner of the inset (its y=4.9 edge, in inset data
# coordinates) to the matching top corner of the highlighted rectangle on
# the main axes (y=5.5, in main-axes data coordinates). Both connectors are
# attached to the inset axes.
# NOTE(review): matplotlib's ConnectionPatch example attaches cross-axes
# connectors with fig.add_artist so they draw above both axes — consider it.
for edge_x in (20, 24):
    connector = ConnectionPatch(
        xyA=(edge_x, 4.9), xyB=(edge_x, 5.5),
        coordsA="data", coordsB="data",
        axesA=axins, axesB=ax,
    )
    axins.add_artist(connector)

保存图像

# Write the current figure to loss.PNG at 300 dpi, cropping surrounding
# whitespace, and only then open the interactive window.
output_file = "loss.PNG"
plt.savefig(output_file, dpi=300, bbox_inches="tight")
plt.show()



版权声明：本文为weixin_47986386原创文章，遵循 CC 4.0 BY-SA 版权协议，转载请附上原文出处链接和本声明。