回顾Adam优化器

AdamW是Adam的升级版,我们先回顾一下 Adam

假如你在训练一个神经网络,目标是找到损失函数的最小损失值。优化器就是你找最小损失值的策略,我们可以把这个过程比作下山,找到这片区域的最低山谷。

Adam优化器是一个不错的策略(放以前来看很好,但现在有AdamW了),它结合了两个方法:

  • 动量:就像下山有惯性一样,Adam 会记住之前几步的方向,这样就走得更稳更快,用术语说就是加速收敛
  • 自适应学习率:山势陡峭(梯度大)的地方,它就小步走,防止摔倒(降低学习率);山势平缓(梯度小)的地方,它就迈大步,加快速度(增大学习率)。它为每个参数都单独调整步幅(也就是学习率)。

Adam 的公式比较复杂(毕竟不复杂就不会这么出名),但核心的参数更新步骤可以简化为:

新参数=旧参数学习率动量梯度平方+一个小常数新参数 = 旧参数 - \frac{学习率 * 动量} {\sqrt{梯度平方} + 一个小常数}

它有一个问题:处理“权重衰减”的方式有缺陷。要理解 AdamW,必须先理解“权重衰减”和“L2正则化”,这两者正是导致Adam出bug的地方,也是AdamW做了优化的地方。

补一下知识:

L2正则化是为了防止模型过拟合,在损失函数后面直接添加一个惩罚项来实现的。
它的公式:最终损失 = 原始损失 + (λ/2) * ||权重||²。在计算梯度时,这个惩罚项会产生一个额外的梯度 λ * 权重,它会推动权重值向零缩小。

权重衰减同样是为了防止过拟合,让权重变小,它在参数更新时,直接从权重里减去一小部分衰减率 * 权重
公式:新权重 = 旧权重 - 学习率 * 梯度 - 衰减率 * 旧权重

在标准的SGD优化器中,L2正则化和权重衰减是等价的,因为我们设置衰减率 = 学习率 * λ,睿智的你就会发现,在梯度计算时这两者的式子时一样的。

但是,在Adam这里它们就不等价了。问题就出在Adam的自适应学习率这里,它会给每个参数都计算各自的步长。在L2正则化计算时,现梯度 = (原始损失的梯度) + (λ * 权重),而Adam更新会把完整的"现梯度" 拿去做动量、平方的计算,然后用它来自适应地调整学习率,此时惩罚项λ * 权重也被调整了,这就导致权重衰减不再是"乘以衰减率",超参数λλ和Adam的自适应学习率耦合在一起。

AdamW优化

AdamW把“权重衰减”和“梯度更新”分开了,实现了超参数λλ和Adam的自适应学习率的解耦

AdamW优化器在计算梯度时,只计算来自原始损失函数的梯度。完全忽略L2正则化项。

然后两步更新参数:

  • 第一步:用Adam的方式,只用原始梯度来更新参数。

新参数=旧参数学习率动量梯度平方+ϵ新参数 = 旧参数 - \frac{学习率 * 动量}{\sqrt{梯度平方} + ϵ}

  • 第二步:另外再执行一次权重衰减,直接从参数中减去一小部分。

新参数=旧参数(学习率衰减率)旧参数新参数 = 旧参数 - (学习率 * 衰减率) * 旧参数

把这两步合起来,就是AdamW的总更新公式:

新参数=旧参数学习率(动量梯度平方+ϵ+衰减率旧参数)新参数 = 旧参数 - 学习率 * (\frac{动量} {\sqrt{梯度平方} + ϵ} + 衰减率 * 旧参数)

这个“+”号就是 AdamW 的灵魂,它意味着权重衰减是独立于自适应学习率之外的一个单独操作。

AdamW优化器更新流程

我们就举AdamW 在第 t 次迭代时的详细更新步骤:

超参数设置:

  • 学习率: η\eta
  • 权重衰减系数: λ\lambda
  • 一阶矩指数衰减率: β1\beta_1
  • 二阶矩指数衰减率: β2\beta_2
  • 数值稳定常数: ϵ\epsilon

然后是流程:先计算梯度,计算原始损失函数对参数 wt1w_{t-1} 的梯度,AdamW这里不含L2正则化项,是不同于Adam的地方。

gt=Lorig(wt1)g_t = \nabla L_{orig}(w_{t-1})

接着更新一阶矩(动量):

mt=β1mt1+(1β1)gtm_t = \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t

然后更新二阶矩(自适应学习率项):

vt=β2vt1+(1β2)gt2v_t = \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2

然后是修正偏差项,修正一阶和二阶矩的偏差:

m^t=mt1β1t  ,  v^t=vt1β2t\hat{m}_t = \frac{m_t}{1 - \beta_1^t}\ \ ,\ \ \hat{v}_t = \frac{v_t}{1 - \beta_2^t}

偏差修正计算后,我们就代入AdamW的更新量公式(这一步和Adam一样):

AdamWupdate=ηm^tv^t+ϵAdamW_{update} = \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}

最后应用解耦的权重衰减和梯度更新,这是AdamW与Adam最不一样的地方,权重衰减直接从当前权重中减去,然后才应用 Adam 的更新量。

wt=wt1AdamWupdateηλwt1w_t = w_{t-1} - AdamW_{update} - \eta \lambda w_{t-1}

补充:
有时候在深度学习中,为了让权重衰减系数 λ\lambda 的量级不依赖于学习率 η\eta,我们会变一下式,使得 λ\lambda 更容易调整:

wt=(1λ)wt1ηm^tv^t+ϵw_t = (1 - \lambda) w_{t-1} - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}