1723 字
9 分钟
[CS336] A1 Chapter 5 Training loop

5 Training loop#

现在我们将把迄今为止构建好的主要组件整合起来:已分词处理的数据、模型以及优化器。

5.1 Data Loader#

已分词化的数据(例如,您在 tokenizer_experiments 中准备的数据)是一个由多个 token 组成的单一序列:x=(x1,,xn)x=(x_1,\dots,x_n)。尽管原始数据可能由多个独立的文档(例如不同的网页或源代码文件)组成,但常见的做法是将所有这些内容合并成一个单一的 token 序列,并在它们之间添加一个分隔符(例如 <|endoftext|>)。

Data Loader 会将这个 token 序列转换为一系列批次,其中每个批次包含长度为 mmBB 个序列,以及与之对应的下一组 tokens,同样长度为 mm。例如,对于 B=1B=1m=3m=3,([x2x_2x3x_3x4x_4], [x3x_3x4x_4x5x_5]) 就是一个可能的批次。

以这种方式加载数据在多个方面简化了训练过程。首先,任何 1inm1 \le i \le n-m 都能给出一个有效的训练序列,因此选择训练序列非常简单。由于所有训练序列具有相同的长度,无需填充输入序列,这提高了硬件利用率(也通过增加批次大小 BB)。最后,我们也不需要加载整个数据集来采样训练数据,这使得处理可能无法存储在内存中的大型数据集变得容易。


Problem (data_loading): Implement data loading#

编写一个函数,该函数接收:

  • 一个名为 xx 的 numpy 数组(token ID 的整数数组)
  • batch_size
  • 上下文长度 context_length
  • PyTorch device string(例如 ‘cpu’‘cuda:0’

返回一个包含两个 Tensor 的 pair:采样输入序列、对应的下一个 token 目标。

这两个 Tensor 的形状都应为 (batch_size, context_length),且都应包含 token ID,并且都应置于指定的设备上。

为了根据我们提供的测试对您的实现进行测试,您首先需要在 [adapters.run_get_batch] 处实现测试适配器。然后,运行 uv run pytest -k test_get_batch 来测试您的实现。


如果数据集太大,无法全部加载到内存中怎么办?我们可以使用名为 mmap 的 Unix 系统调用,它将磁盘上的文件映射到虚拟内存中,并在访问该内存位置时延迟加载文件内容。这样,您就可以“假装”整个数据集都在内存中了。NumPy 通过 np.memmap(或者如果您最初使用 np.save 保存数组,则使用 np.load 的标志 mmap_mode='r')实现了这一点,它会返回一个类似于 NumPy 数组的对象,在您访问时按需加载条目。在训练期间从您的数据集(即 NumPy 数组)中采样时,请务必以内存映射模式加载数据集(通过 np.memmap 或根据您保存数组的方式使用标志 mmap_mode='r'np.load,同时确保还指定了与您加载的数组匹配的数据类型。

可能有必要明确验证内存映射的数据是否正确(例如,不包含超出预期词汇大小的值)。

5.2 Checkpointing#

除了加载数据之外,我们在训练过程中还需要保存模型。在运行任务时,我们通常希望能够从中途停止的训练运行中继续进行(例如,由于任务超时、机器故障等原因)。即使一切顺利,我们可能也希望以后能够访问中间模型(例如,为了事后研究训练动态、从训练的不同阶段的模型中抽取样本等)。

一个检查点应该包含我们恢复训练所需的所有状态。当然,我们至少希望能够恢复模型的权重。如果使用的是有状态的优化器(例如 AdamW),我们还需要保存优化器的状态(例如,对于 AdamW 来说,是矩估计值)。最后,为了恢复学习率调度,我们需要知道我们停止时的迭代次数。

PyTorch 使得保存所有这些都非常容易:每个 nn.Module 都有一个 state_dict() 方法,它会返回一个包含所有可学习权重的字典;我们以后可以用姐妹方法 load_state_dict() 来恢复这些权重,同样适用于任何 torch.optim.Optimizer 类型的对象。最后,torch.save(obj, dest) 可以将一个对象(例如,一个包含张量作为某些值的字典,但也有像整数这样的常规 Python 对象)保存到一个文件(路径)或类似文件的对象中,然后可以使用 torch.load(src) 将其加载回内存中。


Problem (checkpointing): Implement model checkpointing#

实现以下两个函数以加载和保存检查点:

def save_checkpoint(model, optimizer, iteration, out):
	pass
def load_checkpoint(src, model, optimizer):
	pass
  • save_checkpoint 应将模型、优化器和迭代过程中的所有状态都存入一个类似文件的对象 out 中。您可以使用模型和优化器的 state_dict 方法获取它们的相关状态,并使用 torch.save(obj, out) 将 obj 保存到 out 中(PyTorch 在这里支持路径或类似文件的对象)。通常的选择是让 obj 是一个字典,但您可以使用任何您想要的格式,只要您之后能够加载您的检查点即可。

    该函数接收以下参数:

    • model: torch.nn.Module
    • optimizer: torch.optim.Optimizer
    • iteration: int
    • out: str | os.Pathlike | typing.BinaryIO | typing.IO[bytes]
  • load_checkpoint 应从 src(路径或类似文件的对象)加载一个检查点,然后从该检查点恢复模型和优化器的状态。您的函数应返回保存在检查点中的迭代次数。您可以使用 torch.load(src) 来恢复您在 save_checkpoint 实现中保存的内容,并且在模型和优化器中使用 load_state_dict 方法将它们恢复到之前的状态。

    该函数接收以下参数:

    • src: str | os.PathLike | typing.BinaryIO | typing.IO[bytes]
    • model: torch.nn.Module
    • optimizer: torch.optim.Optimizer

实现 [adapters.run_save_checkpoint][adapters.run_load_checkpoint] 这两个适配器,并确保它们能够通过测试 uv run pytest -k test_checkpointing


5.3 Training loop#

现在,终于到了将你所实现的所有组件整合到主训练脚本中的时候了。这样做会很有好处,因为这样便于以不同的超参数(例如,将其作为命令行参数)来开始训练运行,因为之后你还会多次进行这样的操作,以研究不同的选择对训练过程的影响。


Problem (training_together): Put it together#

编写一个脚本,通过训练循环来使用用户提供的输入对模型进行训练。具体而言,我们建议您的训练脚本应至少具备以下功能:

  • 能够配置和控制各种模型和优化器的超参数
  • 使用 np.memmap 实现对大型训练和验证数据集的高效加载
  • 将检查点序列化到用户指定的路径
  • 定期记录训练和验证性能(例如,输出到控制台和/或外部服务,如 Weights and Biases)

[CS336] A1 Chapter 5 Training loop
https://lettle.cn/posts/5-training-loop/
作者
Lettle
发布于
2026-05-26
许可协议
CC BY-NC-SA 4.0