文章预览
来源:https://discuss.pytorch.org/t/distributed-w-torchtitan-enabling-float8-all-gather-in-fsdp2/209323 。下面文章包含2个主题,第一个是FSDP2中开启Float8 All-Gather的Discussion的翻译,第二个是TorchAO中的Float8实现速览的翻译。这篇文档主要介绍了在FSDP2中启用float8 all-gather功能的实现和优化。通过在128个H100 GPU上预训练Llama3-70B模型的验证,相比bfloat16获得了1.50倍的性能提升,其中20%来自float8 all-gather,80%来自float8计算。文档详细描述了Float8训练的两个关键组件:通过torch._scaled_mm实现的float8计算和能节省50%带宽的float8通信。在优化策略方面,通过Float8计算+Bfloat16 All-Gather获得1.40倍加速,再通过带独立AMAX All-Reduce的Float8 All-Gather和组合AMAX AllReduce分别获得0.02倍和0.08倍的额外提升,同时还优化了NCCL和Float8计算之间的SM资源竞争。文档还提供了完整的代码示例,展示了如何将nn.Linear
………………………………