画loss、acc曲线

  • Post author: hamburgerJ
  • Post category:其他


# -*-coding:utf-8-*-
import re

from matplotlib import pyplot as plt
import numpy as np


def ReadData(data_loc):
    """Parse a training log into per-epoch lists.

    Every non-blank line of the file at *data_loc* is split on single spaces
    and three fields are read from it:

    * field 1 -- must start with the epoch number; the whole leading run of
      digits is used, so tokens such as ``"12"`` or ``"12/50"`` both give 12;
    * field 3 -- the training loss, parsed with ``float``;
    * field 6 -- the test accuracy, parsed with ``float``.

    Args:
        data_loc: Path of the log file to read.

    Returns:
        A tuple ``(epoch_list, train_loss_list, test_accuracy_list)`` of
        equal-length lists in file order.

    Raises:
        ValueError: if a line's epoch field does not start with a digit, or
            its loss/accuracy field is not a number.
        IndexError: if a non-blank line has fewer than seven fields.
    """
    epoch_list = []
    train_loss_list = []
    test_accuracy_list = []

    with open(data_loc, "r") as f:
        for line_no, line in enumerate(f, start=1):
            # Skip blank lines (e.g. a stray empty last line) rather than
            # failing on a missing field.
            if not line.strip():
                continue

            # Split on single spaces so the field indices match the log
            # layout exactly (consecutive spaces would yield empty fields).
            data = line.split(' ')

            # Use every leading digit, not just the first character, so that
            # epochs >= 10 are not truncated to their first digit.
            epoch_match = re.match(r"\d+", data[1])
            if epoch_match is None:
                raise ValueError(
                    f"{data_loc}:{line_no}: epoch field {data[1]!r} "
                    "does not start with a digit"
                )

            epoch_list.append(int(epoch_match.group()))
            train_loss_list.append(float(data[3]))
            # float() tolerates the trailing newline on the last field.
            test_accuracy_list.append(float(data[6]))

    return epoch_list, train_loss_list, test_accuracy_list


def _use_ticks_style():
    """Apply the seaborn "ticks" style under whichever name matplotlib knows.

    matplotlib 3.6 renamed its bundled seaborn styles to ``seaborn-v0_8-*``
    and newer releases no longer accept the old ``seaborn-ticks`` name, so
    try the new name first, then the old one. If neither is available, the
    current style is kept (best effort: styling is cosmetic).
    More style names: ``plt.style.available`` or
    https://blog.csdn.net/viviliving/article/details/107690844
    """
    for style_name in ("seaborn-v0_8-ticks", "seaborn-ticks"):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            return


def _draw_comparison(title, ylabel, series, labels, max_epochs):
    """Plot two per-epoch series on a new figure, add a legend, and show it.

    Each series is clipped to its first *max_epochs* points (all points when
    ``None``). It is plotted against 0-based epoch indices of its own length,
    so the two runs do not need the same number of epochs.
    """
    _use_ticks_style()
    # Start a new figure so consecutive calls never draw onto the same axes
    # (plt.show() may return without closing the figure, e.g. on
    # non-interactive backends).
    plt.figure()
    plt.title(title)
    plt.xlabel("epoch")
    plt.ylabel(ylabel)

    for values, label in zip(series, labels):
        clipped = list(values)[:max_epochs]
        plt.plot(range(len(clipped)), clipped, linewidth=3, label=label)

    plt.legend()
    plt.show()


def DrawLoss(train_loss_list, train_loss_list_2, max_epochs=10,
             labels=("with pretrain", "no pretrain")):
    """Plot the loss curves of two runs on one figure.

    Args:
        train_loss_list: Per-epoch loss of the first run (legend label
            ``labels[0]``, "with pretrain" by default).
        train_loss_list_2: Per-epoch loss of the second run (``labels[1]``,
            "no pretrain" by default).
        max_epochs: Plot at most this many leading epochs of each run;
            ``None`` plots every epoch. Defaults to 10.
        labels: Legend labels for the two runs.
    """
    _draw_comparison("Loss", "loss", (train_loss_list, train_loss_list_2),
                     labels, max_epochs)


def DrawAcc(train_loss_list, train_loss_list_2, max_epochs=10,
            labels=("with pretrain", "no pretrain")):
    """Plot the accuracy curves of two runs on one figure.

    The parameter names are kept from the original signature so keyword
    callers keep working, even though they hold accuracies, not losses.

    Args:
        train_loss_list: Per-epoch accuracy of the first run (``labels[0]``).
        train_loss_list_2: Per-epoch accuracy of the second run
            (``labels[1]``).
        max_epochs: Plot at most this many leading epochs of each run;
            ``None`` plots every epoch. Defaults to 10.
        labels: Legend labels for the two runs.
    """
    _draw_comparison("Accuracy", "accuracy",
                     (train_loss_list, train_loss_list_2), labels, max_epochs)


if __name__ == '__main__':
    # Log of the run plotted as "with pretrain", followed by the log of the
    # run plotted as "no pretrain".
    pretrained_log, scratch_log = "loss_pre.txt", "loss.txt"

    # ReadData returns (epochs, train losses, test accuracies); the epoch
    # numbers are not needed for plotting.
    pre_loss, pre_acc = ReadData(pretrained_log)[1:]
    scratch_loss, scratch_acc = ReadData(scratch_log)[1:]

    # Loss figure first, then accuracy, each comparing the two runs.
    DrawLoss(pre_loss, scratch_loss)
    DrawAcc(pre_acc, scratch_acc)

①上面的两个txt都是在训练时写入的,每行格式如下:字段之间用单个空格分隔,第2个字段以epoch编号开头,第4个字段是训练loss,第7个字段是测试准确率(分别对应代码中的 data[1]、data[3]、data[6])。

②效果:

下面两张图对比了使用与不使用预训练参数时 ResNet-18 的训练效果。



版权声明:本文为hamburgerJ原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。