5 Training loop
现在我们将把迄今为止构建好的主要组件整合起来:已分词处理的数据、模型以及优化器。
5.1 Data Loader
已分词化的数据(例如,您在 tokenizer_experiments 中准备的数据)是一个由多个 token 组成的单一序列:。尽管原始数据可能由多个独立的文档(例如不同的网页或源代码文件)组成,但常见的做法是将所有这些内容合并成一个单一的 token 序列,并在它们之间添加一个分隔符(例如 <|endoftext|>)。
Data Loader 会将这个 token 序列转换为一系列批次,其中每个批次包含长度为 的 个序列,以及与之对应的下一组 tokens,同样长度为 。例如,对于 ,,([, , ], [, , ]) 就是一个可能的批次。
以这种方式加载数据在多个方面简化了训练过程。首先,任何 都能给出一个有效的训练序列,因此选择训练序列非常简单。由于所有训练序列具有相同的长度,无需填充输入序列,这提高了硬件利用率(也通过增加批次大小 )。最后,我们也不需要加载整个数据集来采样训练数据,这使得处理可能无法存储在内存中的大型数据集变得容易。
Problem (data_loading): Implement data loading
编写一个函数,该函数接收:
- 一个名为 的 numpy 数组(token ID 的整数数组)
- batch_size
- 上下文长度 context_length
- PyTorch device string(例如 ‘cpu’ 或 ‘cuda:0’)
返回一个包含两个 Tensor 的 pair:采样输入序列、对应的下一个 token 目标。
这两个 Tensor 的形状都应为 (batch_size, context_length),且都应包含 token ID,并且都应置于指定的设备上。
为了根据我们提供的测试对您的实现进行测试,您首先需要在 [adapters.run_get_batch] 处实现测试适配器。然后,运行 uv run pytest -k test_get_batch 来测试您的实现。
如果数据集太大,无法全部加载到内存中怎么办?我们可以使用名为 mmap 的 Unix 系统调用,它将磁盘上的文件映射到虚拟内存中,并在访问该内存位置时延迟加载文件内容。这样,您就可以“假装”整个数据集都在内存中了。NumPy 通过 np.memmap(或者如果您最初使用 np.save 保存数组,则使用 np.load 的标志 mmap_mode='r')实现了这一点,它会返回一个类似于 NumPy 数组的对象,在您访问时按需加载条目。在训练期间从您的数据集(即 NumPy 数组)中采样时,请务必以内存映射模式加载数据集(通过 np.memmap 或根据您保存数组的方式使用标志 mmap_mode='r' 的 np.load),同时确保还指定了与您加载的数组匹配的数据类型。
可能有必要明确验证内存映射的数据是否正确(例如,不包含超出预期词汇大小的值)。
5.2 Checkpointing
除了加载数据之外,我们在训练过程中还需要保存模型。在运行任务时,我们通常希望能够从中途停止的训练运行中继续进行(例如,由于任务超时、机器故障等原因)。即使一切顺利,我们可能也希望以后能够访问中间模型(例如,为了事后研究训练动态、从训练的不同阶段的模型中抽取样本等)。
一个检查点应该包含我们恢复训练所需的所有状态。当然,我们至少希望能够恢复模型的权重。如果使用的是有状态的优化器(例如 AdamW),我们还需要保存优化器的状态(例如,对于 AdamW 来说,是矩估计值)。最后,为了恢复学习率调度,我们需要知道我们停止时的迭代次数。
PyTorch 使得保存所有这些都非常容易:每个 nn.Module 都有一个 state_dict() 方法,它会返回一个包含所有可学习权重的字典;我们以后可以用姐妹方法 load_state_dict() 来恢复这些权重,同样适用于任何 torch.optim.Optimizer 类型的对象。最后,torch.save(obj, dest) 可以将一个对象(例如,一个包含张量作为某些值的字典,但也有像整数这样的常规 Python 对象)保存到一个文件(路径)或类似文件的对象中,然后可以使用 torch.load(src) 将其加载回内存中。
Problem (checkpointing): Implement model checkpointing
实现以下两个函数以加载和保存检查点:
def save_checkpoint(model, optimizer, iteration, out):
pass
def load_checkpoint(src, model, optimizer):
passsave_checkpoint 应将模型、优化器和迭代过程中的所有状态都存入一个类似文件的对象 out 中。您可以使用模型和优化器的
state_dict方法获取它们的相关状态,并使用torch.save(obj, out)将 obj 保存到 out 中(PyTorch 在这里支持路径或类似文件的对象)。通常的选择是让 obj 是一个字典,但您可以使用任何您想要的格式,只要您之后能够加载您的检查点即可。该函数接收以下参数:
- model:
torch.nn.Module - optimizer:
torch.optim.Optimizer - iteration:
int - out:
str | os.Pathlike | typing.BinaryIO | typing.IO[bytes]
- model:
load_checkpoint 应从 src(路径或类似文件的对象)加载一个检查点,然后从该检查点恢复模型和优化器的状态。您的函数应返回保存在检查点中的迭代次数。您可以使用
torch.load(src)来恢复您在save_checkpoint实现中保存的内容,并且在模型和优化器中使用load_state_dict方法将它们恢复到之前的状态。该函数接收以下参数:
- src:
str | os.PathLike | typing.BinaryIO | typing.IO[bytes] - model:
torch.nn.Module - optimizer:
torch.optim.Optimizer
- src:
实现 [adapters.run_save_checkpoint] 和 [adapters.run_load_checkpoint] 这两个适配器,并确保它们能够通过测试 uv run pytest -k test_checkpointing
5.3 Training loop
现在,终于到了将你所实现的所有组件整合到主训练脚本中的时候了。这样做会很有好处,因为这样便于以不同的超参数(例如,将其作为命令行参数)来开始训练运行,因为之后你还会多次进行这样的操作,以研究不同的选择对训练过程的影响。
Problem (training_together): Put it together
编写一个脚本,通过训练循环来使用用户提供的输入对模型进行训练。具体而言,我们建议您的训练脚本应至少具备以下功能:
- 能够配置和控制各种模型和优化器的超参数
- 使用
np.memmap实现对大型训练和验证数据集的高效加载 - 将检查点序列化到用户指定的路径
- 定期记录训练和验证性能(例如,输出到控制台和/或外部服务,如 Weights and Biases)
