跳转至

如何用1024张显卡训练一个模型

最近看到知乎一个回答,把千卡训练的难度吹上天了。但其实真正用过千卡就会发现也就那么几个点。于是想写一篇文章简单讲讲。

本文将包括三个部分:首先我们将讨论千卡训练的难题,以及应该在什么时候使用千卡训练;接着,我们将讨论如何在一千张卡上开始训练,如何让他达到近乎线性的性能提升;最后我们将展开讨论一些千卡训练当中仍然悬而未决(至少对于开源社区来说)的问题。

为什么千卡训练是困难的?

其实那篇回答在这部分说的没错。千卡训练和八卡训练的区别是—显卡多了一百多倍。

这意味着什么呢?

  1. 通信时间增加
  2. 故障概率增加

这俩问题都很好理解。

时间上,PyTorch 内部支持 NCCL / Gloo / MPI 三个通信后端(请务必使用 NCCL。其中训网络最常用的 AllReduce 操作会会根据具体硬件配置走 Ring AllReduce 和 Tree AllReduce。Ring 的时间复杂度是 \(O(p n)\),Tree 的时间复杂度是 \(O(\log p n)\)。就算是理论上 128 节点也比单节点慢至少七倍,实践当中跨节点通信要远比单节点慢得多。

故障上,一个节点出问题的概率是 \(p\),128 个节点就是 \(1-(1-p^{128})\)。也就是说如果一个操作在一个训练当中的出错概率是 1%,那么在 128 节点当中的出错概率就是 72.37%。

此外,随着规模的增大,许多问题都会变得难以忍受。比如数据增强要花 0.1s,一亿条数据就是 278 个小时(当然这只是胡拆的一个数字,实际有各种机制所以不会有这么大影响。

因此,钱多烧手并不是使用千卡训练的理由。闲得蛋疼可能是,但你得多蛋疼才能想出这么折磨自己的 idea?

因此,千卡训练解决的问题是大模型&大数据问题。如果你的训练时间没有超过 8192 GPU 日,那么你绝对不需要一千张显卡。

看到这里,绝大多数人已经可以关掉这篇文章了。除非你的模型和数据都以 B(十亿)来作为计量单位。当然如果你正在厕所里手机没电想看点儿东西解闷儿的话(虽然我很怀疑是否会有人把他打出来……那么可以继续往下看

如何使用一千张卡训练?

如何提高计算效率?

这件事情其实是一个 case by case 的事情。因为通信、计算速度啥的受硬件影响更多。同样是 A100 集群,我全 DGX 节点,每一张 A100 都是 SXM 接口并配一块儿专属的 IB 网卡。你一个小破普惠服务器插 8 张 PCI-E A100,IB 卡一个节点只给一张。那咱俩遇到的问题就完全不是一个问题。

因此,要讨论如何提高训练效率、减少训练耗时,我们首先要了解训练耗时在哪里。那么,一个训练步的耗时在哪里呢?需要谨记,没有 profile 的优化是没有意义的。

你可能会说,forward backward sync。很好,这说明你了解 PyTorch 的基本流程。不过现实当中要复杂得多。

  1. dataset 读取数据,构建输出
  2. dataloader collate 数据,进行数据预处理
  3. 模型 forward 计算输出
  4. loss compute
  5. 模型 backward 计算梯度
  6. 模型 sync 梯度
  7. 优化器 step 更新权重
  8. 打印 log

当然这是可以无限细分下去的,但一般这些就够了。需要注意的是,除了 4-7 的耗时是真耗时,其他都需要通过异步操作来盖掉。这也是我们的优化目标。

异步执行在 PyTorch 的 dataloader、CUDA 和分布式当中都存在。前者可以通过设置 num_workers 和 prefetch_count 为 0 来关闭,后两者可以通过 cuda.synchornize 和 dist.barrier 来执行手动同步。在 profile 时,我们需要首先需要测整个 step 的时长。然后再在每次测量前执行手动同步来计算每个部分的时长。如果前者的总耗时等于后者 4-7 的耗时之和,那么通常不需要执行任何操作。但这种情况在千卡操作中几乎不可能发生。

第 6 步通信往往需要耗费大量时间。因此,我们还需要进一步优化通信。

以下内容是对PyTorch Distributed的概括,有感兴趣的同学建议通读并背诵全文。

计算-通信重叠

在 PyTorch 当中,梯度的通信和反向传播是交叠进行的。也就是说,每完成一层的梯度计算,都会立即触发当前层的同步。实现起来也很简单,每个进程在完成自己第 k 层的梯度计算后都会触发一个钩子来给计数器+1s。当计数器达到进程数时开火进行梯度通信。有很多同学在计算梯度过程中遇到过 RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. 错误,这就是因为有的模块没有参与计算 loss,导致梯度同步卡住了。需要注意,当 find_unused_parameters=True 时,PyTorch 分布式使用 nn.Module.__init__ 当中定义子模块的反向顺序来作为梯度桶的构建顺序。因此,确保模块定义和调用的顺序一致是一个良好的实践。

梯度合桶

尽管理论上来说,同步发生的越及时,重合度越高,性能越好。但实际上每次发起通信都是有上头的。因此,现实当中梯度同步并不是越多越好越快越好。为此,PyTorch 引入了梯度合桶机制,通过把多个 Tensor 装在一个桶里再通信桶来减少通信次数从而减少总耗时。合桶的 bucket_cap_mb 默认是 25MiB,这对于绝大多数模型来说都是太小的。目前已经有提升这个默认值的特性需求,但是这个还是调一下更好。

梯度累加

当你做完所有操作之后,惊喜的发现 TMD 怎么同步时间还是单节点的好几倍。这其实是正常情况……实际上超过 256 卡的训练想要把通信盖掉就是一件不可能的事情。你说老师我看 FB 论文说他们 256 卡就是线性提升啊…那这里不得不提的一个策略就是梯度累加了。梯度累加会执行 k 次 forward+backward 之后再执行优化器步进。这有很多好处,首先对于大模型 batch size 通常不能开多大,梯度累加可以提升等效 batch size。其次累加期间的 backward 不需要通信梯度,加快了训练速度。

少即是快

Python 是一种很慢的语言。当然你说 JIT trace+torch.compile 有提升我也不反对,但对于最高效率来说,只有必须要存在的代码和不存在的代码两种。

抱抱脸的 Transformers 就是一个反例。两个子模块就能写完的 TransformerLayer 他们硬是能写出来一堆…偏偏他们还信奉 Single Model File Policy……我寻思你这完全不考虑继承的封这么多层是要搞鸡毛啊?正例反而是 PyTorch……(笑死,我竟然会夸脸书代码写得好。具体来说就是 nn.functional 当中的各种实现。你会发现他们第一行往往是 handle_torch_func。熟悉 Python 装饰器的小伙汁通常要问了,为啥这里不用个装饰器统一一下?因为装饰器会引入额外的函数调用,额外的函数调用就是额外的上头。

因此,如果你想确保最高的效率,写一个简单的训练代码和模型代码非常重要。毕竟,1%的效率提升,节省的可能是数百个 GPU 日。

如何平稳训练

这一段当中中咱们只讨论你能控制的问题。

捕捉不致命的异常

故障率高的问题其实很好解决。在训练当中,大部分异常都是非致命异常,接住他们就好了。我之前写过一个装饰器,catch,它的作用就是接住异常,然后调回调函数(默认当然就是把错误打印到 log 里)。所有你需要做的只是使用它来装饰所有非 fatal 的操作。

在实际应用当中,我们遇到的最常见的问题是存 ckpt 写满了磁盘(不准笑,从商汤到上海 AI Lab,这个问题在哪儿都日常出现。咱也不知道为啥肯买那么多显卡但不肯多插点儿硬盘,咱也不敢问)。接住所有保存操作,如果你有闲心可以在回调里删一下之前的 ckpt。没闲心的话…大不了重训一次嘛(逃。第二常见的问题,你猜对了……存 log 写满了硬盘……所以所有 logging 操作也都是要 catch 的。这就是为啥我都用 tmux 然后开很长的缓存窗口,总是能抢救一些 log 出来的。

咳咳,说点儿正经的。任何联网操作都是需要 catch 的,常见的联网操作主要包括从 ceph 读取数据和…写 log 到远程(逃。其他就没啥了吧,我见过有大哥尝试恢复 OOM 的,但效果似乎不是很好,至少我自己没用过。简单来说,唯一不应捕捉的错误是集群炸了。

那有的大兄弟就说了,集群没爆炸,但是有两张卡突然掉了咋办。这个咱第三部分再讨论。

管好模型的输出

模型训着训着发散了几乎是每个训大模型的人都会遇到的问题。输出和 loss 只要有 nan 果断丢掉。梯度先 clip by value 再 clip by norm 都是常规操作。哦对了,还有初始化……关于大模型收敛性的论文有一堆,此处不再赘述。

比更大,还更大,再更大

弹性训练

实际上当你的训练超过 2048 个 GPU 日时,在整个训练过程当中发生单个 GPU 甚至单个节点下线是再正常不过的事情了。

PyTorch 在 1.10 就引入了 torchelastic 弹性训练机制。这个东西,用过的都骂娘。等下,让我先骂一遍。呸。ok 咱们继续吧。

我印象当中在微软的最后一轮面试当中被问到了这个问题:如何设计一个弹性分布式系统。

我的回答很教科书。每 k 分钟,系统会做一次 AllGather 来统计存活进程数,然后选举出一个主进程。主进程会计算好每个进程的 rank 和 local rank 然后广播给各个进程。所有进程每次前向传播开始时向主进程发送一个心跳包来汇报状态。主进程会根据心跳包来确定这一个 step 参与同步的机器有多少。

但很可惜,2024 年了。还是没人去写。他妈的。

层次化梯度同步

我一直认为梯度同步不应该以 GPU/进程为单位。而应该分为大同步(节点间同步)和小同步(节点内同步)。小同步可以更高频的进行,大同步则可以更慢的执行。这样不仅能提高实际的梯度同步频率,降低同步总耗时,并且还能天然的去结合小 batch 和大 batch 训练的优点—节点内小 batch 关注个体,节点间大 batch 关注整体。

延伸阅读

PyTorch DDP 设计笔记

PyTorch 微调菜谱

分析与优化使用 PyTorch 训练机器学习模型

使用 Nsight Systems 来分析 GPU 负载

DLProf

NVProf

NCCL AllReduce 设计

Spike No More

Understanding the difficulty of training deep feedforward neural networks

不知不觉间,飞机开始进近,这篇文章也该走到终点了。本篇文章更多是基于香草 PyTorch 的概念性文章,实际使用当中大家更多还是用 DeepSpeed 或者 Megatron-LM 这样专门为大规模训练而设计的框架。训练细节也要多一些。但本文的核心思想是,用千卡训模型,真的没多难。只是大多数时候没有人有这个需要罢了。


甲辰年三月二十

于中国南海