TypeError: forward() missing 1 required positional argument: 'target'

  • Post author:
  • Post category:其他


报错的代码行:

los = loss(np.array([output, label]))

# 获取损失函数
    loss = torch.nn.CrossEntropyLoss()
# 开始训练
    for epoch in range(args.num_epoch):
        loss_sum = []
        accuracies = []
        for batch_id, (spec_mag, label) in enumerate(train_loader):
            spec_mag = spec_mag.to(device)
            label = label.to(device).long()
            output = model(spec_mag)
            # 计算损失值
            los = loss(np.array([output, label]))
            optimizer.zero_grad()
            los.backward()
            optimizer.step()

            # 计算准确率
            output = torch.nn.functional.softmax(output)
            output = output.data.cpu().numpy()
            output = np.argmax(output, axis=1)
            label = label.data.cpu().numpy()
            acc = np.mean((output == label).astype(int))
            accuracies.append(acc)
            loss_sum.append(los)
            if batch_id % 100 == 0:
                print('[%s] Train epoch %d, batch: %d/%d, loss: %f, accuracy: %f' % (
                    datetime.now(), epoch, batch_id, len(train_loader), sum(loss_sum) / len(loss_sum), sum(accuracies) / len(accuracies)))
            
        

        Loss_list.append(los / (len(train_dataset)))
        Accuracy_list.append(acc )
        for index,item in enumerate(Accuracy_list):
              print(index,item)

            
            
            
        scheduler.step()
        # 评估模型
        acc = test(model, test_loader, device)
        print('='*70)
        print('[%s] Test %d, accuracy: %f' % (datetime.now(), epoch, acc))
        print('='*70)
        model_path = os.path.join(args.save_model, 'resnet34.pth')
        if not os.path.exists(os.path.dirname(model_path)):
            os.makedirs(os.path.dirname(model_path))
        torch.jit.save(torch.jit.script(model), model_path)

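该错误与 CrossEntropyLoss() 的调用方式有关:forward() 需要 input 和 target 两个独立的张量参数,而 loss(np.array([output, label])) 把两者打包成了一个参数,target 缺失,因此报错。应改为 los = loss(output, label)。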

【Pytorch】交叉熵损失函数 CrossEntropyLoss() 详解


【Pytorch】交叉熵损失函数 CrossEntropyLoss() 详解_想变厉害的大白菜的博客-CSDN博客_crossentropyloss pytorch

nn.CrossEntropyLoss()可接受两种输入


nn.CrossEntropyLoss()可接受两种输入_ImangoCloud的博客-CSDN博客_crossentropyloss输入

input=[batch_size,num_classes]

target=[batch_size](类别索引),或 target=[batch_size,num_classes](类别概率)



版权声明:本文为m0_53342923原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。