9 行代码移除归一化层,Transformer性能不降反升?解密DyT

B站影视 日本电影 2025-04-02 07:06 1

摘要:在深度学习,特别是大火的transformer模型(比如ChatGPT、Stable Diffusion背后的技术)中,Normalization(归一化)层,尤其是Layer Normalization (LN),几乎是无处不在的“标配”。它们就像模型训练过

在深度学习,特别是大火的transformer模型(比如ChatGPT、Stable Diffusion背后的技术)中,Normalization(归一化)层,尤其是Layer Normalization (LN),几乎是无处不在的“标配”。它们就像模型训练过程中的“稳定器”,被认为对于模型的稳定收敛和高性能至关重要。数据归一化层作为模型重要的一个组件,貌似是无法替代的存在,而大家也习惯了这样的做法,无论是前置归一化层还是后置归一化层,任何一个基于 transformer 的LLM 大语言模型都无一例外的添加了归一化层。

transformer 模型框架

但是,如果我告诉你,这个“标配”可能不是必需的呢?来自Meta AI、纽约大学、MIT和普林斯顿大学的顶尖研究者们(包括Kaiming He、Yann LeCun等)最近发表了一篇论文,提出了一种极其简单的方法,让Transformer模型在不需要任何Normalization层的情况下,依然能达到甚至超越原有性能!

自从2017年Google提出"Attention Is All You Need"以来,Transformer架构就凭借其强大的并行处理能力和捕捉长距离依赖关系的能力,彻底改变了自然语言处理(NLP)领域,并迅速扩展到计算机视觉(CV)、语音识别等多个领域。

transformer 模型

一个典型的Transformer Block通常包含几个核心组件:

自注意力机制 (Self-Attention): 这是Transformer的灵魂,让模型能够动态地关注输入序列中不同部分的重要性。前馈神经网络 (Feed-forward Network, FFN): 对注意力机制处理后的信息进行进一步的非线性变换。残差连接 (Residual Connections): 将输入直接加到输出上,帮助缓解梯度消失问题,让网络可以做得更深。层归一化 (Layer Normalization, LN): 通常在自注意力或FFN之后、残差连接之前或之后使用,用于稳定训练过程中的激活值。

几乎所有主流的Transformer模型,从BERT到GPT系列,再到ViT(Vision Transformer),都严格遵循着这个包含LN的结构。

那么,LN究竟是做什么的?

简单来说,Layer Normalization是在每个样本的每个时间步(或token)上,独立地对特征维度(embedding dimension)进行归一化。它计算该token所有特征的均值和方差,然后用这些统计量来标准化这些特征,最后再通过两个可学习的参数(缩放因子γ和偏移因子β)进行仿射变换。

LN的主要作用是:

稳定训练: 归一化后的激活值分布更稳定,减少了内部协变量偏移(Internal Covariate Shift)的影响,让训练过程更平稳,可以使用更高的学习率。加速收敛: 训练更稳定通常意味着收敛更快。提升性能: 特别是在深层网络中,LN有助于梯度传播,使得训练更深、更复杂的模型成为可能。

因为这些显著的优点,LN(以及它的变种如RMSNorm)成为了Transformer架构中雷打不动的组件。以至于近年的研究大多在尝试改进Attention或FFN,却很少有人去质疑Normalization层的必要性。

layer norm 输入输出关系

Transformers without Normalization这篇论文的研究者们做了一个有趣的观察实验。他们可视化了训练好的Transformer模型中LN层的输入和输出关系。结果惊人地发现:LN层的实际作用,很多时候表现得非常像一个缩放版的tanh(双曲正切)函数!

LN层的输入输出映射呈现出明显的S形曲线。对于接近零的输入值,输出几乎是线性的;但对于绝对值较大的“极端”输入值,LN层会将其“压扁”到一个较小的范围内,这和tanh函数的特性非常相似。

tanh函数

这个发现启发了研究者:LN层的关键作用可能并非精确计算均值方差进行标准化,而是在于对输入进行适当缩放,并通过类似tanh的饱和特性来“压制”极端激活值,从而稳定网络。

基于这个洞察,他们提出了一个极其简单的替代方案,命名为Dynamic Tanh (DyT)。它的计算公式是:

DyT(x) = γ * tanh(α * x) + β

这里的x是输入张量。关键在于:

tanh(α * x): 使用tanh函数来模拟LN的S形映射和压制极端值的效果。α (alpha): 这是一个可学习的标量参数。它动态地学习一个合适的缩放因子,来调整tanh函数作用的范围,间接模拟了LN根据输入数据调整尺度的能力。这就是“Dynamic”的由来。γ (gamma) 和 β (beta): 这两个是可学习的向量参数(维度与特征维度C相同),与LN中的仿射变换参数完全一样,允许输出被重新缩放和平移到任意范围。

DyT

直接将原始Transformer模型中所有的LN(或RMSNorm)层,替换成DyT层即可。不需要计算均值方差,每个元素独立计算,实现起来非常简单。而代码也是只需要 9 行代码,直接代替 transformer 模型中的 layer norm 数据归一化层的代码即可。

class DyT(Module): def __init__(self, C, init_α): super.__init__ self.α = Parameter(ones(1) * init_α) self.γ = Parameter(ones(C)) self.β = Parameter(zeros(C)) def forward(self, x): x = tanh(self.alpha * x) return self.γ * x + self.β

研究者们在各种任务和模型上验证了DyT的效果,结果令人印象深刻:

性能持平或更好: 无论是在图像分类(ViT, ConvNeXt)、图像生成(Diffusion Transformer)、自监督学习(MAE, DINO)、大语言模型(LLaMA)、语音处理(wav2vec 2.0)还是DNA序列建模(HyenaDNA, Caduceus)等多种任务上,使用DyT替换LN/RMSNorm后的模型,性能几乎都能达到甚至略微超过原始模型,而且很多时候不需要调整原始模型的超参数!

训练稳定性: DyT模型的训练损失曲线与LN模型高度相似,表明DyT也能提供足够的训练稳定性。

计算效率提升: 由于DyT是简单的element-wise(逐元素)操作,避免了计算均值方差的开销,其计算速度比LN/RMSNorm更快。在LLaMA 7B模型的测试中,DyT层本身的推理和训练延迟分别减少了约52%和42%,带动整个模型的延迟也有所下降。

组件重要性分析: 实验证明,tanh函数的非线性压制作用和可学习的α都至关重要。如果去掉tanh(只用线性变换)会导致训练崩溃;如果去掉可学习的α(固定α=1),性能会下降。

优于其他无Normalization方法: 与Fixup、SkipInit等其他尝试移除Normalization的方法相比,DyT在同等条件下表现更优。

α的初始化: 对于大多数模型,默认将α初始化为0.5效果就很好。但在大型语言模型(LLM)中,精细调整α的初始值(特别是对Attention块和FFN块使用不同的初始值)可以带来显著的性能提升。

这项工作有力地证明了,在强大的Transformer架构中,我们长期依赖的Normalization层可能并非不可或缺。Dynamic Tanh (DyT)作为一种极其简洁的替代方案,不仅能够复现甚至超越LN的效果,还带来了计算效率上的优势。

transformer 模型作为大语言模型的核心框架,一直以来由于其性能壁垒,研究人员一直在寻找可替代方法,以及可以优化的方法,包含 flash attention,稀疏注意力机制,MoE混合专家模型,潜在注意力机制,Mamba 等线性模型等等,但是无论如何改进,其注意力机制是 transformer 模型的核心算法。是否有可以取代 transformer 注意力机制的算法,我们拭目以待。

来源:人工智能研究所

相关推荐