背景

Title: BitNet: Scaling 1-bit Transformers for Large Language Models

Author: Hongyu Wang, Shuming Ma, Li Dong, Shaohan Huang, Huaijie Wang, Lingxiao Ma, Fan Yang, Ruiping Wang, Yi Wu, Furu Wei

Publish: arxiv

Year: 2023

Code: https://github.com/microsoft/BitNet

Keyword: 量化,Transformer,大模型

信息

现有的大部分量化方法都是后训练的,然而,尤其是在准确率较低的情况下,它们的性能不好

即在训练时没有为量化表示优化,有一些像大模型没有冗余和留有余力

另一种是是quantization-aware training,可以取得更好的性能,可以继续训练或者微调,但难以训练,并且是否遵循scaling law

以往对于有一些研究,但:

  • 聚焦于CNN
  • 聚焦于双向自注意力

bert的encoder-decoder等模型没有展现出scaling law

贡献:

  • 提出了BitNet,一个1bit transformer结构,在训练时,权重和激活函数使用低精度,优化器状态和梯度使用高精度
  • 可以scale且具有较高的稳定性
  • 模型相对简单,对于nn.Linear进行重构,并且,引入了诸如PagedAttention,FalshAttention,speculative decoding

方法

Screenshot 2025-04-28 at 19.40.25

BitNet结构与transformer类似,只是使用BitLinear改写了所有线性层,保留了其它为高精度:

  • 残差连接和layer normlization对于LLM来说计算量很少
  • QKV的计算比参数映射少很多
  • 保留了输入/输出,来进行采样

BitLinear

权重二值化

首先用符号函数sig将权重二值化,具体来说:

W~=Sign(Wα),Sign(Wij)={+1, if Wij>0,1, if Wij0,α=1nmijWij\begin{aligned} &\widetilde{W}=\operatorname{Sign}(W-\alpha),\\ &\begin{aligned} \operatorname{Sign}\left(W_{i j}\right) & = \begin{cases}+1, & \text { if } W_{i j}>0, \\ -1, & \text { if } W_{i j} \leq 0,\end{cases} \\ \alpha & =\frac{1}{n m} \sum_{i j} W_{i j} \end{aligned} \end{aligned}

  • 将权重中心化为均值为0,即减α\alpha,尽量平衡容量
  • 进行二值化
  • 使用一个缩放系数β\beta来缩小真值和二值化后的L2损失
输入b-bit化

作者称这一步为激活?

使用最大值缩放,即将激活值缩放到[Qb,Qb](Qb=2b1)[-Q_b, Q_b](Q_b=2^{b-1})之间:

x~=Quant(x)=Clip(x×Qbγ,Qb+ϵ,Qbϵ),Clip(x,a,b)=max(a,min(b,x)),γ=x,\begin{aligned} &\widetilde{x}=\operatorname{Quant}(x)=\operatorname{Clip}\left(x \times \frac{Q_b}{\gamma},-Q_b+\epsilon, Q_b-\epsilon\right),\\ &\operatorname{Clip}(x, a, b)=\max (a, \min (b, x)), \quad \gamma=\|x\|_{\infty}, \end{aligned}

其中,ϵ\epsilon是防止Clip时溢出的极小值

特别的,对于在非线性激活层之前的,通过减去均值将其缩放到[0,Qb][0, Q_b]之间:

x~=Quant(x)=Clip((xη)×Qbγ,ϵ,Qbϵ),η=minijxij.\widetilde{x}=\operatorname{Quant}(x)=\operatorname{Clip}\left((x-\eta) \times \frac{Q_b}{\gamma}, \epsilon, Q_b-\epsilon\right), \quad \eta=\min _{i j} x_{i j} .

TODO: why

一些细节:b=8,且训练时针对每个tensor,推理时针对每个token

归一化

由此,BitLinear可以写成:

y=W~x~y=\widetilde{W} \widetilde{x}

假设W和xx中的元素是独立同分布的,且W和xx相互独立,那么yy的方差就是:

Var(y)=nVar(w~x~)=nE[w~2]E[x~2]=nβ2E[x~2]E[x~2]\begin{aligned} \operatorname{Var}(y) & =n \operatorname{Var}(\widetilde{w} \widetilde{x}) \\ & =n E\left[\widetilde{w}^2\right] E\left[\widetilde{x}^2\right] \\ & =n \beta^2 E\left[\widetilde{x}^2\right] \approx E\left[\widetilde{x}^2\right] \end{aligned}

对于全精度训练,由于标准初始化(kaiming init, Xavier init等)会把输出的方差缩放到1

为了在量化之后方差还为1,量化之前加入了LayerNorm

总体

总体公式就为:

y=W~x~=W~Quant(LN(x))×βγQbLN(x)=xE(x)Var(x)+ϵ,β=1nmW1\begin{gathered} y=\widetilde{W} \widetilde{x}=\widetilde{W} \operatorname{Quant}(\operatorname{LN}(x)) \times \frac{\beta \gamma}{Q_b} \\ \operatorname{LN}(x)=\frac{x-E(x)}{\sqrt{\operatorname{Var}(x)+\epsilon}}, \quad \beta=\frac{1}{n m}\|W\|_1 \end{gathered}

组并行化

现有的模型并行化都是基于在切分维度上张量是独立的,但这里的α\alphaβ\betaγ\gamma,均值,方差等是由整个张量计算

尽管可以使用all-reduce的方案,但同步过程太大

将权重和激活值分成不同的组,形成Group Quantization,一个权重分为GG组,一组有nG×m\frac{n}{G} \times m个参数,然后在组间并行:

αg=GnmijWij(g),βg=GnmW(g)1\alpha_g=\frac{G}{n m} \sum_{i j} W_{i j}^{(g)}, \quad \beta_g=\frac{G}{n m}\left\|W^{(g)}\right\|_1

同样,也对于输入分为GG组:

γg=x(g),ηg=minijxij(g)\gamma_g=\left\|x^{(g)}\right\|_{\infty}, \quad \eta_g=\min _{i j} x_{i j}^{(g)}

而LN也是:

LN(x(g))=x(g)E(x(g))Var(x(g))+ϵ\mathrm{LN}\left(x^{(g)}\right)=\frac{x^{(g)}-E\left(x^{(g)}\right)}{\sqrt{\operatorname{Var}\left(x^{(g)}\right)+\epsilon}}

训练

Straight-through estimator

使用STE估计越过不可微的部分(符号函数,Clip等)

混合精度训练

梯度和优化器状态仍然使用高精度来保证稳定性和准确度,并为高精度的可学习参数维护了一个潜在状态来累计参数更新,其在训练时为二值,在推理时不使用

大学习率

通常,对于潜在权重的小的更新通常不起作用

大学习率对于二值网络收敛非常重要

计算效率

不同节点乘加的效率:

image-20250429100938696

对于原始transformer,其能量消耗是:

Eadd=m×(n1)×p×E^addEmul=m×n×p×E^mul\begin{aligned} E_{a d d} & =m \times(n-1) \times p \times \hat{E}_{a d d} \\ E_{m u l} & =m \times n \times p \times \hat{E}_{m u l} \end{aligned}

对于BitNet,由于是1bit,只有1或-1,主要运算为加法,能量消耗为:

Emul=(m×p+m×n)×E^mulE_{m u l}=(m \times p+m \times n) \times \hat{E}_{m u l}

实验

与FP16 transformer对比

实验设定

训练了不同大小的BitNet(从125M到30B)

数据集为英文资料文集,包括Pile,Common Crawl snapshots,RealNews,CC-Stories

使用了Sentencpiece序列化器,词表大小为16K

推理最佳scaling law

大致符合公式:

L(N)=aNb+cL(N)=a N^b+c

image-20250429102312393

以往的计算量计算方法不适合于BitNet,提出了Inference-Optimal Scaling Law,是用能量来计算

下游任务

使用了0-shot和4-shot

image-20250429110612808

稳定性

image-20250429110721798

与后训练量化对比

image-20250904130058617

消融实验

image-20250904130206465

总结

探究了如何从头训练一个量化的网络,提出了构建Transformer的基础bit结构:

  • 构建了稳定的bit线性层计算和训练,模型可以在此基础上进行拓展
  • 证明了bit结构下,可以在性能不损失过多的情况下,显著减少计算资源消耗
  • 证明bit结构下,模型也可以符合scaling law

然而:

  • 实验最大参数量只有8B,更多参数量性能为预估值
  • 下游任务相对简单