【pytorch记录】SummaryWriter保存日志

  • Post author:
  • Post category:其他



在pytorch框架中,关于日志的保存,其中一种方式就是借鉴使用了tensorboard的库。所以我们需要在环境中安装tensorboard库,然后再在工程中进行该库的调用




1 安装与导入


  • 安装:

    conda install tensorboardX

    或者

    pip install tensorboardX


  • 导入

     from tensorboardX import SummaryWriter
     writer = SummaryWriter(logPath)
     ...
     writer.close()
    




2 添加需要保存标量数据

在这里插入图片描述



  • add_scalar(tag, scalar_value, global_step=None)

    从源码中我们能看到核心的三个参数为前三个。通俗的讲分别代表


    • tag:图的标签名,唯一标识

      scalar_value:y轴数据,标量数据的具体数值

      global_step:x轴数据,要记录的全局步长值


  • add_scalars(main_tag, tag_scalar_dit)

    多项标题记录方法,其中:


    • main_tag —— 该图的标签

      tag_salar_dict —— 字典形式的tag-scalar_value对


源码中也有例子:

from tensorboardX import SummaryWriter
import numpy as np

writer = SummaryWriter('run/logs')

max_epoch = 100
for x in range(max_epoch):

    writer.add_scalar('t/y=2x', x * 2, x)    #x*2为y轴数据,x为x轴数据
    writer.add_scalar('t/y=pow_2_x', 2^x, x)
    writer.add_scalars('scalar_group', {"xsinx": x * np.sin(x),
                                     "xcosx": x * np.cos(x)}, x)
    writer.close()



运行完该脚本后,运行tensorboard命令:

tensorboard --logdir=./run/


在这里插入图片描述

在浏览器中打开链接:【http://localhost:6006/】

在这里插入图片描述



3

添加需要保存图片数据

在这里插入图片描述


从源码中我们能看到

add_image

的主要参数如下。通俗的讲分别代表


  • tag:曲线图名字,唯一标识

  • img_tensor:图片数据,类型要求为 tensor/numpy/string 等

  • global_step:要记录的全局步长值

  • dataformats:图片输入的默认维度。注意是”CHW”

from tensorboardX import SummaryWriter
import numpy as np
img = np.zeros((3, 100, 100))
img[0] = np.arange(0, 10000).reshape(100, 100) / 10000
img[1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000

img_HWC = np.zeros((100, 100, 3))
img_HWC[:, :, 0] = np.arange(0, 10000).reshape(100, 100) / 10000
img_HWC[:, :, 1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000

writer = SummaryWriter('run/logs')
writer.add_image('my_image', img, 0)

# If you have non-default dimension setting, set the dataformats argument.
writer.add_image('my_image_HWC', img_HWC, 0, dataformats='HWC')
writer.close()

在这里插入图片描述




4 直方图的记录


画直方图主要为了看参数的分布状态,使用

add_histogram(tag, values, global_step=None, bins=’tensorflow’, walltime=None)

,其中tag, value, global_step的含义同上,示例如下:

# 每个epoch,记录梯度,权值
for name, param in net.named_parameters():
    writer.add_histogram(name + '_grad', param.grad, epoch)
    writer.add_histogram(name + '_data', param, epoch)




5 网络结构的记录


展示结构图使用

add_graph(model, input_to_model=None, verbose=False)

writer = SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix")
# 模型
fake_img = torch.randn(1, 3, 32, 32)
yolo = Yolo(classes=2)
writer.add_graph(yolo, fake_img)
writer.close()



版权声明:本文为magic_ll原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。