1. 相关理论概述
Temporal Ensembling方法通过预测的标签的EMA(exponential moving average),并且通过保证ensemle 模型和 trained模型预测标签的连续一致性,从而保证训练得到的ensemle模型尽可能的接近groud truth模型。这里可以理解为,如果如果模型是正确的,那么前后两个模型的预测标签应该是接近的,并且变化较小的,那么使模型向使两个模型预测结果接近的方向移动,就是向groudtruth model移动。这种方法,每一个epoch标签数据仅仅会改变一次,对于大规模数据,或者在线学习问题,该方法就不能很好的适用。论文《Mean teachers better role models: Weight-averaged…》提出了平均权重的方法,而不是Temporal Ensembling中采用的label平均的方法,可以在每一个training step更新teacher model,及时的指导student model的学习。在ImageNet 2012上,使用10%的labels,将top5的精度误差率从35.24%下降9.11%。
2. 算法概述
网络整体的架构包括两个部分student model和teacher model:student model的网络参数通过学习,梯度下降获得。teacher model的网络参数通过student model的网络参数的moving average得到。
-
teacher model的网络参数的更新方法:通过student model网络参数的moving average得到
、 - student model的网络参数更新方法:通过损失函数的梯度下降更新参数得到。其中损失函数包括两个部分:有监督损失函数,保证有标签训练数据拟合;第二部分是无监督损失函数,主要是保证student model的预测结果和teacher model的预测结果尽量的相似。因为teacher model的参数是student model的网络参数的moving average,所以,对于任何新来的数据,预测结果都不应该有太大的抖动。如果如果模型是正确的,那么前后两个模型的预测标签应该是接近的,并且变化较小的,那么使模型向使两个模型预测结果接近的方向移动,就是向groudtruth model移动。
3. 算法流程
假设有一批训练样本X1,X2,其中X1使有标签数据(对应标签是z1),X2使无标签数据。具体的训练过程如下:
- 把这一批样本作为student网络输入,然后分别得到输出的标签:ys1,ys2;
- 构造对于有标签数据X1的损失函数,有标签分类损失函数L1(z1,ys1);
- 把这批数据作为teacher model的输入,得到输出的标签yt1,yt2;
-
构造无监督损失函数L2,论文中采用MSE损失函数:
-
总损失函数L1+L2梯度下降,更新student model的网络参数,通过moving average更新teacher model的网络参数
α选择
在网络开始训练阶段,由于参数是随机初始化而来,student的参数肯定是不正确的,所以构成的teacher的参数也是不正确的。应该以student学习到的为准,所以α值应该从零开始,随着网络的训练,student达到一定的准确率之后,就可以采用ensemble思想。可以以teacher网络的参数为准,最终达到α为0.99。
augmentation以及添加噪声
对于同样的样本,通过augmentation和加入噪声后,得到的teacher和student的输入是不同的,所以两者的输出也应该是不同的,这样训练出来的网络具有对噪声的鲁棒性。