-
理解AdamW
-
我们先弄清楚什么是weight decay
-
其实是在损失函数求导后,放在正则项前面的系数,比如L2正则,我们看一下weight decay的位置
-
我
们
可
以
认
为
λ
就
是
w
e
i
g
h
t
d
e
c
a
y
min
w
L
2
(
w
)
=
min
w
f
(
w
)
+
λ
2
n
∑
i
=
1
n
w
i
2
L
2
′
(
w
)
=
f
′
(
w
)
+
λ
n
∑
i
=
1
n
w
i
我们可以认为\lambda就是weight\ decay\\ \min_wL_2(w)=\min_wf(w)+\frac{\lambda}{2n}\sum_{i=1}^nw_i^2\\ L_2^{‘}(w)=f^{‘}(w)+\frac{\lambda}{n}\sum_{i=1}^nw_i
我们可以认为λ就是weight decaywminL2(w)=wminf(w)+2nλi=1∑nwi2L2′(w)=f′(w)+nλi=1∑nwi
- AdamW是在Adam+L2正则化的基础上进行改进的算法。使用Adam优化带L2正则的损失并不有效。如果引入L2正则项,在计算梯度的时候会加上对正则项求梯度的结果。
- 那么如果本身比较大的一些权重对应的梯度也会比较大,由于Adam计算步骤中减去项会除以梯度平方的累积开根号,使得减去项偏小。按常理说,越大的权重应该惩罚越大,但是在Adam并不是这样。分子分母相互抵消掉了。
- 而权重衰减对所有的权重都采用相同的系数进行更新,越大的权重显然惩罚越大。
- 在常见的深度学习库中只提供了L2正则,并没有提供权重衰减的实现。
- paper地址
-
Adam+L2 VS AdamW
图片中红色是传统的Adam+L2 regularization的方式,绿色是Adam + weight decay的方式。可以看出两个方法的区别仅在于”系数乘以上一步参数值”(这一项实际上就是权重乘以L2项的导数,因为
x
2
x^2
x2的导数是本身x。)这一项的位置。
再结合代码来看一下AdamW的具体实现。
以下代码来自https://github.com/macanv/BERT-BiLSTM-CRF-NER/blob/master/bert_base/bert/optimization.py中的AdamWeightDecayOptimizer中的apply_gradients函数中,BERT中的优化器就是使用这个方法。
在代码中也做了一些注释用于对应之前给出的Adam简化版公式,方便理解。可以看出update += self.weight_decay_rate * param这一句是Adam中没有的,也就是Adam中绿色的部分对应的代码,weightdecay这一步是是发生在Adam中需要被更新的参数update计算之后,并且在乘以学习率learning_rate之前,这和图片中的伪代码的计算顺序是完全一致的。总之一句话,如果使用了weightdecay就不必再使用L2正则化了。
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
"""See base class."""
assignments = []
for (grad, param) in grads_and_vars:
if grad is None or param is None:
continue
param_name = self._get_variable_name(param.name)
m = tf.get_variable(
name=param_name + "/adam_m",
shape=param.shape.as_list(),
dtype=tf.float32,
trainable=False,
initializer=tf.zeros_initializer())
v = tf.get_variable(
name=param_name + "/adam_v",
shape=param.shape.as_list(),
dtype=tf.float32,
trainable=False,
initializer=tf.zeros_initializer())
# Standard Adam update.
next_m = (
tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
next_v = (
tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
tf.square(grad)))
update = next_m / (tf.sqrt(next_v) + self.epsilon)
# Just adding the square of the weights to the loss function is *not*
# the correct way of using L2 regularization/weight decay with Adam,
# since that will interact with the m and v parameters in strange ways.
#
# Instead we want ot decay the weights in a manner that doesn't interact
# with the m/v parameters. This is equivalent to adding the square
# of the weights to the loss with plain (non-momentum) SGD.
if self._do_use_weight_decay(param_name):
update += self.weight_decay_rate * param
update_with_lr = self.learning_rate * update
next_param = param - update_with_lr
assignments.extend(
[param.assign(next_param),
m.assign(next_m),
v.assign(next_v)])
return tf.group(*assignments, name=name)