Pytorch Lightning 踩坑记录

  • Post author:
  • Post category:其他


最近两周开始了解和上手学习pytorch lightning,这个框架,csdn和知乎资料都比较少,而且框架相对年轻,还是有不少该改进的地方

lightning 可以和torch 兼容,在某些方面反而有冲突,可能需要自己写callback 函数实现

今天记录一下自己的踩坑(不分先后)



模型test 无输出

在这里插入图片描述

正确代码

def test_step(self, batch, batch_idx):
    x, y = batch
    logits = self.forward(x)
    preds = torch.argmax(logits, dim=1)
    acc = accuracy(preds, y)
    f1_score = f1(preds, y, num_classes=self.num_classes)
    # self.log("f1", f1_score)
    self.log("test_acc", acc,logger=False,on_epoch=True,)

必须

logger=True

这样的话也会记录到tensotboard,之前设置为false 就是觉得tensotboard中有一个标量不是曲线,只有一个点太难看,结果就没有输出了

在这里插入图片描述



版权声明:本文为weixin_44831720原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。