3095 字
15 分钟
[CS336] A1 Chapter 4 Training a Transformer LM

4 Training a Transformer LM#

现在我们已经掌握了预处理数据(通过 Tokenizer)和模型(Transformer)的步骤。接下来要做的就是构建所有支持训练的代码。这包括以下内容:

  • 损失函数:我们需要定义损失函数(cross-entropy)
  • 优化器:我们需要定义优化器以最小化此损失(AdamW)
  • 训练循环:我们需要所有基础设施,包括加载数据、保存检查点以及管理训练过程

4.1 Cross-entropy loss#

回想一下,Transformer 语言模型为每个序列 xx 定义了一个分布 pθ(xi+1x1:i)p_{\theta}(x_{i+1} | x_{1:i}),长度为 m+1m+1,且i=1,,mi=1,\dots,m。给定一个由长度为 m+1m+1 的序列组成的训练集 DD,我们定义标准交叉熵(负对数似然)损失函数:

l(θ;D)=1DmxDi=1mlogpθ(xi+1x1:i)(16)\mathscr{l} (\theta; D) = \frac{1}{|D|m} \sum_{x \in D} \sum_{i=1}^{m} -log p_\theta (x_{i+1} | x_{1:i}) \tag{16}

(注意,在Transformer中,单次前向传播会为所有 i=1,,mi=1,\dots,m 生成 pθ(xi+1x1:i)p_\theta(x_{i+1} | x_{1:i})。)

具体来说,Transformer 为每个位置 ii 计算对数几率 oiRvocab_sizeo_i \in \mathbb{R}^{vocab\_size} ,结果为:

p(xi+1x1:i)=softmax(oi)[xi+1]=exp(oi[xi+1])a=1vocab_sizeexp(oi[a])(17)p(x_{i+1} | x_{1:i}) = softmax(o_i)[x_{i+1}] = \frac{exp(o_i[x_{i+1}])} {\sum^{vocab\_size}_{a=1} exp(o_i[a])} \tag{17}

交叉熵损失通常是以向量形式的对数输出值 oiRvocab_sizeo_i \in \mathbb{R}^{vocab\_size} 以及目标值 xi+1x_{i+1} 为依据来定义的。

在使用交叉熵损失时,同样需要小心处理数值相关的问题,这一点与使用 softmax 函数时的情况是相同的。


Problem (cross_entropy): Implement cross-entropy#

编写一个函数来计算交叉熵损失,该函数接收预测的对数概率(oio_i)和目标值(xi+1x_{i+1}),并计算交叉熵损失值 li=log softmax(oi)[xi+1]\mathscr{l}_i = -log\ softmax(o_i)[x_{i+1}]。你的函数应执行以下操作:

  • 为确保数值稳定性,需减去最大值。
  • 尽可能消除对数和指数运算。
  • 处理任何额外的批次维度,并计算批次内的平均值。如同第 3.2 节中,我们假设批处理式的维度总是排在词汇大小维度之前。

实现 [adapters.run_cross_entropy] ,然后运行 uv run pytest -k test_cross_entropy 来测试您的实现。


困惑度

交叉熵对于训练来说已经足够了,但在评估模型时,我们还希望报告困惑度。对于长度为 𝑚 的序列,在遭受交叉熵损失 l1,,lm\mathscr{l}_1,\dots,\mathscr{l}_m 的情况下:

perplexity=exp(1mi=1mli)(18)perplexity = exp(\frac{1}{m} \sum^{m}_{i=1} \mathscr{l}_i) \tag{18}

4.2 SGD Optimizer#

既然我们已经有了损失函数,接下来我们就来探讨优化器。最简单的基于梯度的优化器是随机梯度下降(SGD)。我们从随机初始化的参数 θ0\theta_0 开始,然后对于每一步 t=0,,T1t=0,\dots,T-1,我们执行以下更新操作:

θt+1θtαtL(θt;Bt)(19)\theta_{t+1} \leftarrow \theta_t - \alpha_t \nabla L(\theta_t ; B_t) \tag{19}

其中,BtB_t 是从数据集 DD 中随机抽取的一组数据,而学习率 αt\alpha_t批次大小 Bt|B_t| 则是超参数。

4.2.1 Implementing SGD in PyTorch#

为了实现我们的优化器,我们将对 PyTorch 的 torch.optim.Optimizer 类进行子类化。一个优化器子类必须实现两个方法:

def __init__(self, params, ...):
	pass

def step(self):
    pass
  • __init__() 应当初始化您的优化器。在此,params 将是一个要进行优化的参数集合(或者在用户希望为模型的不同部分使用不同的超参数(如学习率)的情况下,会是参数组)。请务必将 params 传递给基类的 __init__ 方法,该方法会将这些参数存储起来以便在 step 中使用。您可以根据优化器的不同情况添加其他参数(例如,学习率是一个常见的参数),并将它们作为字典的形式传递给基类构造函数,其中 keys 是您为这些参数所选择的名称(字符串)
  • step() 应当对参数进行一次更新。在训练循环中,这将在反向传播之后被调用,因此您能够获取最后一批数据的梯度。此方法应遍历每个参数张量 p 并对其进行原地修改,即设置 p.data,它包含了与该参数相关联的张量,该张量基于梯度 p.grad(如果存在的话)来设置,即基于表示该参数相对于损失的梯度的张量。

PyTorch 的优化器 API 存在一些细微之处,因此通过一个示例来解释会更易于理解。为了使我们的示例更加丰富,我们将实现一种与随机梯度下降(SGD)略有不同的变体,其中学习率会在训练过程中逐渐降低,初始学习率为 α\alpha,随着时间的推移会逐步减小步长:

θt+1=θtαt+1L(θt;Bt)(20)\theta_{t+1} = \theta_t - \frac{\alpha} {\sqrt{t+1}} \nabla L(\theta_t ; B_t) \tag{20}

让我们来看看这种版本的随机梯度下降算法是如何被实现为 PyTorch Optimizer 的:

from collections.abc import Callable, Iterable
from typing import Optional
import torch
import math

class SGD(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3):
    	if lr < 0:
    		raise ValueError(f"Invalid learning rate: {lr}")
        defaults = {"lr": lr}
        super().__init__(params, defaults)
        
    def step(self, closure: Optional[Callable] = None):
        loss = None if closure is None else closure()
        for group in self.param_groups:
            lr = group["lr"] # Get the learning rate.
            for p in group["params"]:
                if p.grad is None:
                	continue
                state = self.state[p] # Get state associated with p.
                t = state.get("t", 0) # Get iteration number from the state, or 0.
                grad = p.grad.data # Get the gradient of loss with respect to p.
                p.data -= lr / math.sqrt(t + 1) * grad # Update weight tensor in-place.
                state["t"] = t + 1 # Increment iteration number.
        return loss

__init__ 方法中,我们将参数以及默认的超参数传递给基类的构造函数(这些参数可能以组的形式出现,每组具有不同的超参数)。如果参数只是单个包含 torch.nn.Parameter 对象的集合,基类构造函数将创建一个单一的组,并为其分配默认超参数。然后,在 step 方法中,我们遍历每个参数组,再遍历该组中的每个参数,并应用 公式 20。这里,我们将迭代次数作为与每个参数关联的状态保存下来:我们首先读取这个值,在梯度更新中使用它,然后进行更新。该 API 规定用户可以传递一个可调用闭包来在优化器步骤之前重新计算损失。对于我们将要使用的优化器,我们不需要这个功能,但为了符合 API 规定,我们还是添加了它。

为了查看其工作原理,我们可以使用以下训练循环的最小示例:

weights = torch.nn.Parameter(5 * torch.randn((10, 10)))
opt = SGD(weights, lr=1)

for t in range(100):
    opt.zero_grad() # Reset the gradients for all learnable parameters.
    loss = (weights**2).mean() # Compute a scalar loss value.
    print(loss.cpu().item())
	loss.backward() # Run backward pass, which computes gradients.
	opt.step() # Run optimizer step.

这就是训练循环的典型结构:在每次迭代中,我们会计算损失并执行优化器的一个 step。在训练语言模型时,可学习的参数将来自模型(在 PyTorch 中,m.parameters() 可以给我们提供这个集合)。损失将基于一个采样批次的数据进行计算,但训练循环的基本结构是相同的。


Problem (learning_rate_tuning): Tuning the learning rate#

正如我们将会看到的那样,影响训练效果最显著的一个超参数就是学习率。 让我们通过我们的示例来实际看看这一点。使用上述的随机梯度下降示例,并将学习率设置为另外三个值:1e1、1e2 和 1e3,仅进行 10 次训练迭代。对于这些不同的学习率,损失会怎样变化呢?它会更快地降低、更慢地降低,还是会发散(即在训练过程中不断增加)?


4.3 AdamW#

现代语言模型通常采用更为复杂的优化器进行训练,而非随机梯度下降(SGD)。近期使用的大多数优化器都是 Adam 优化器的衍生版本。我们将使用 AdamW,它在近期的研究中被广泛应用。AdamW 提出了对 Adam 的改进,通过添加权重衰减(在每次迭代中,我们将参数拉向 0),从而改进了正则化效果,这种方式与梯度更新是解耦的。我们将按照 I. Loshchilov 等人 算法 2 中的描述来实现 AdamW。

AdamW 是有状态的:对于每个参数,它会记录其一、二阶矩的运行估计值。因此,AdamW 会使用更多的内存以换取更好的稳定性和收敛性。除了学习率 α\alpha 之外,AdamW 还有两个超参数(β1,β2\beta_1,\beta_2)来控制矩估计值的更新,并有一个权重衰减率 λ\lambda。典型的应用将(β1,β2\beta_1,\beta_2)设置为(0.9,0.999),但像 LLaMA 和 GPT-3 这样的大型语言模型通常使用(0.9,0.95)作为训练参数。该算法可以如下编写,其中 ϵ\epsilon 是一个小值(例如 10810^{-8})用于在 vv 中出现极小值时提高数值稳定性:


Algorithm 1: AdamW Optimizer

  1. init(θ\theta) ▷ 初始化可学习参数
  2. m0m \leftarrow 0 ▷ 一阶矩向量的初始值;与 θ\theta 具有相同形状
  3. v0v \leftarrow 0 ▷ 二阶矩向量的初始值;与 θ\theta 具有相同形状
  4. for t=1,,Tt=1,\dots,T do
  5. ​ 选取数据批次 BtB_t
  6. gθl(θ;Bt)g \leftarrow \nabla_\theta \mathscr{l}(\theta ; B_t) ▷ 计算损失函数的梯度
  7. αtα1β2t1β1t\alpha_t \leftarrow \alpha \frac{\sqrt{1- \beta_2^t}}{1-\beta_1^t} ▷ 计算第 tt 次迭代的调整后的 α\alpha
  8. θθαλθ\theta \leftarrow \theta - \alpha\lambda\theta ▷ 应用权重衰减
  9. mβ1m+(1β1)gm \leftarrow \beta_1m+(1-\beta_1)g ▷ 更新一阶矩估计
  10. vβ2v+(1β2)g2v \leftarrow \beta_2v+(1-\beta_2)g^2 ▷ 更新二阶矩估计
  11. θθαtmv+ϵ\theta \leftarrow \theta - \alpha_t\frac{m}{\sqrt{v}+\epsilon} ▷ 应用 moment-adjusted 权重更新
  12. end for

请注意,变量 tt 从 1 开始,现在您需要实现这个优化器。


Problem (adamw): Implement AdamW#

将 AdamW 优化器实现为 torch.optim.Optimizer 类的一个子类。您的类在初始化时应设置学习率 α\alpha 以及超参数 β\betaϵ\epsilonλ\lambda 。为了帮助您保存状态,基础的 Optimizer 类为您提供了一个名为 self.state 的字典,该字典将 nn.Parameter 对象与一个存储该参数所需任何信息的字典进行映射(对于 AdamW 来说,这将是矩估计)。请实现 [adapters.get_adamw_cls] 并确保它能通过 uv run pytest -k test_adamw 测试。


4.4 Learning rate scheduling#

在训练过程中,能使损失值下降最快的学习率值往往会有所变化。在训练 Transformer 模型时,通常会采用学习率调度策略,即一开始使用较大的学习率,以便在初始阶段进行较快的更新,随后随着模型训练的进行,逐渐将其降低至较小的值。在本次任务中,我们将实现用于训练 LLaMA 的余弦退火调度策略(cosine annealing schedule)。

调度器其实就是一个函数,它会接收当前 step tt 以及其他相关参数(比如初始学习率和最终学习率),并返回在 step tt 时用于梯度更新的学习率。最简单的学习率调度方式是常函数,它会根据任何的 tt 值返回相同的学习率。

余弦退火学习率调度(cosine annealing schedule)需要:

(i)当前迭代次数 tt

(ii)最大学习率 αmax\alpha_{max}

(iii)最小(最终)学习率 αmin\alpha_{min}

(iv)warm-up 迭代次数 TwT_w

(v)余弦退火的最终迭代次数 TcT_c

在迭代次数为 tt 时的学习率定义为:

(Warm-up)t<Twt<T_w 时,αt=tTwαmax\alpha_t = \frac{t}{T_w} \alpha_{max}

(Cosine annealing)TwtTcT_w \le t \le T_c 时,αt=αmin+12(1+cos(tTwTcTwπ))(αmaxαmin)\alpha_t = \alpha_{min} + \frac{1}{2}(1+cos(\frac{t-T_w}{T_c-T_w}\pi))(\alpha_{max}-\alpha_{min})

(Post-annealing)t>Tct>T_c 时, αt=αmin\alpha_t = \alpha_{min}


Problem (learning_rate_schedule): Implement cosine learning rate schedule with warmup#

编写一个函数,该函数接收参数 ttαmax\alpha_{max}αmin\alpha_{min}TwT_wTcT_c,并根据上述定义的调度器返回学习率 αt\alpha_t 。然后实现 [adapters.get_lr_cosine_schedule] 并确保其能通过 uv run pytest -k test_get_lr_cosine_schedule


4.5 Gradient clipping#

在训练过程中,有时会遇到一些训练样本会产生较大的梯度,这可能会导致训练过程不稳定。为解决这一问题,实践中经常采用的一种技术是梯度裁剪。其原理是在每次反向传播后,对梯度的范数进行限制,然后再进行优化器步骤。 对于所有参数的梯度 gg ,我们计算其 l2\mathscr{l}_2 范数 g2||g||_2 。如果该范数小于最大值 MM ,则保持 gg 不变;否则,我们将 gg 的值乘以因子 Mg2+ϵ\frac{M}{||g||_2+\epsilon}(其中添加一个小的 ϵ\epsilon,例如 10610^{-6} 以保证数值稳定性)。需要注意的是,得到的范数将略小于 MM


Problem (gradient_clipping): Implement gradient clipping#

编写一个函数来实现梯度裁剪。您的函数应接受一个参数列表和一个最大 l2\mathscr{l}_2 范数。它应直接修改每个参数的梯度。使用 ϵ=106\epsilon=10^{-6}(这是 PyTorch 的默认值)。然后,实现适配器 [adapters.run_gradient_clipping] 并确保它通过 uv run pytest -k test_gradient_clipping


[CS336] A1 Chapter 4 Training a Transformer LM
https://lettle.cn/posts/4-training-a-transformer-lm/
作者
Lettle
发布于
2026-05-25
许可协议
CC BY-NC-SA 4.0