今天看啥  ›  专栏  ›  小白玩转Python

在 CIFAR10 数据集上训练 Vision Transformer (ViT)

小白玩转Python  · 公众号  ·  · 2024-11-09 20:46
    

文章预览

点击下方 卡片 ,关注“ 小白玩转Python ”公众号 在这篇简短的文章中,我将构建一个简单的 ViT 并将其训练在 CIFAR 数据集上。 训练循环 我们从训练 CIFAR 数据集上的模型的样板代码开始。我们选择批量大小为64,以在性能和 GPU 资源之间取得平衡。我们将使用 Adam 优化器,并将学习率设置为0.001。与 CNN 相比,ViT 收敛得更慢,所以我们可能需要更多的训练周期。此外,根据我的经验,ViT 对超参数很敏感。一些超参数会使模型崩溃并迅速达到零梯度,模型的参数将不再更新。因此,您必须测试与模型大小和形状本身以及训练超参数相关的不同超参数。 transform_train = transforms.Compose([ transforms.Resize(32), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) transform_test = transforms.Compose([ transforms.Resize(32), transforms.ToTensor(), ………………………………

原文地址:访问原文地址
快照地址: 访问文章快照
总结与预览地址:访问总结与预览