Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

四块GPU即可训练BigGAN:「官方版」PyTorch实现出炉

作为「史上最强 GAN 图像生成器」,BigGAN 自去年 9 月推出以来就成为了 AI 领域最热词。其生成图像的目标和背景都高度逼真、边界自然,简直可以说是在「创造新物种」。然而 BigGAN 训练时需要的超高算力(128-512 个谷歌 TPU v3 核心)却让很多想要参与制图狂欢的开发者望而却步。

今日,BigGAN 论文的第一作者、来自英国 Heriot-Watt 大学的 Andrew Brock 发布了 BigGAN 的 PyTorch 版实现。最令人高兴的是:这一次训练模型的算力要求被降低到 4 到 8 块 GPU 了!

项目链接:https://github.com/ajbrock/BigGAN-PyTorch

该项目一出即引发了人们的广泛关注,有的人表示不敢相信,也有人哭晕在 Colab。
Brock 本次放出的 BigGAN 实现包含训练、测试、采样脚本以及完整的预训练检查点(生成器、判别器和优化器),以便你可以在自己的数据上进行微调或者从零开始训练模型。

作者表示,这些代码制作时间很长,从一开始就被设计成可操控、可扩展的基础,以方便未来的研究。作者花了很多心思考虑在什么地方具体使用什么抽象,以确保它们有效但又易于理解或改变。

这一工作是 Andrew Brock 与 MIT 的 Alex Andonian 一起完成的。

BigGAN 的 PyTorch 实现

这是由论文原作者正式发布的「非官方」BigGAN PyTorch 实现。

该 repo 包含用 4-8 个 GPU 训练 BigGAN 的代码。

如何使用

你需要用到:

  • 1.0.1 版本的 PyTorch

  • tqdm、numpy、scipy 和 h5py

  • ImageNet 训练集

首先,你可以准备目标数据集的预处理 HDF5 版本,以便更快地输入/输出(可选)。在此之后(不管是否如此做了),你需要计算 FID 所需的 Inception moment。这些都可以通过修改并运行以下代码来完成:

sh scripts/utils/prepare_data.sh

默认情况下,假设你的 ImageNet 训练集已经下载至此目录的根文件夹 data 中,然后以 128x128 的像素分辨率准备缓存的 HDF5。

脚本文件夹中有多个 bash 脚本,此类脚本可以用不同的批量大小训练 BigGAN。这段代码假设你无法访问完整的 TPU pod,然后通过梯度累积(将多个小批量上的梯度平均化,然后仅在 N 次累积后采取优化步骤)表示相应的 mega-batches。默认情况下,launch_BigGAN_bs256x8.sh 脚本训练批量大小为 256 且具备 8 次梯度累积的完整 BigGAN 模型,其总的批量大小为 2048。在 8xV100 上进行全精度训练(无张量核),这个脚本需要 15天训练到 15 万次迭代。

你需要先确定你的设置能够支持的最大批量。这里提供的预训练模型是在 8xV100(每个有 16GB VRAM)上训练的,8xV100 能支持比默认使用的 BS256 略大的批量大小。一旦确定了这一点,你应该修改脚本,使批大小乘以梯度累积的数量等同于你期望的总批量大小(BigGAN 默认的总批量大小是 2048)。

注意,这个脚本使用参数 --load_in_mem,该参数会将整个 I128.hdf5(约 64GB)文件加载至 RAM 中,以便更快地加载数据。如果你没有足够的 RAM 来支持这个(可能需要 96GB 以上),删除这个参数

度量和采样

在训练过程中,该脚本将输出包含训练度量和测试度量的日志,并保存模型权重/优化器参数的多个副本(2 个最新的和 5 个得分最高的),还会在每次保存权重时产生样本和插值。日志文件夹包含处理这些日志及使用 MATLAB 绘制结果的脚本。

训练结束后,你可以使用 sample.py 生成额外的样本和插值,用不同的截断值、批大小、standing stat 累积次数等进行测试。示例参考 sample_BigGAN_bs256x8.sh 脚本。

默认情况下,所有内容都会保存至 weights/samples/logs/data 文件夹中,这些文件夹应与该 repo 在同一文件夹中。你可以使用 --base_root 参数将这些文件夹指向不同的根目录,或者使用对应的参数(如 --logs_root)为每个文件夹选择特定的位置。

该 repo 还包含运行 BigGAN-deep 的脚本,但作者尚未使用它们来完整地训练模型,所以可将其视为未经测试。另外,该 repo 包含在 CIFAR 上运行模型的脚本,以及在 ImageNet 上运行 SA-GAN(带有EMA)和 SN-GAN 的脚本。SA-GAN 代码假设你有 4xTitanX(或具备同等 RAM 的 GPU),并使用 128 的批量大小和 2 个梯度累积来训练。

关于 Inception 度量的重要提示

该 repo 使用 PyTorch 内置 inception 网络来计算 IS 和 FID 分数。这些分数与使用官方 TF inception 代码得到的不同,且仅用于监控目的。使用 --sample_npz 参数在模型上运行 sample.py,然后运行 inception_tf13 来计算真实的 TensorFlow IS。注意:你需要安装 TensorFlow 1.3 或更早版本,因为 TF1.4+ 会破坏原始 IS 代码。

预训练模型

该 repo 包含两个预训练模型检查点(具备 G、D、G 的 EMA copy、优化器和 state dict):

使用 Places-365 数据集预训练模型也将很快开源。

该 repo 还包含将原始 TFHub BigGAN Generator 权重迁移到 PyTorch 的脚本。详见 TFHub 文件夹。

使用自己的数据集或新的训练函数对模型进行微调


如果你想继续被中断的训练或者微调预训练模型,运行同样的启动脚本,不过这次需要添加 —resume 参数。实验名称是从配置中自动生成的,但是你可以使用 —experiment_name 参数对其进行重写(例如你想使用修改后的优化器设置来微调模型)。

要想使用自己的数据集,你需要将其添加到 datasets.py,并修改 utils.py 中的 convenience dicts (dset_dict, imsize_dict, root_dict, nclass_dict, classes_per_sheet_dict),从而为自己的数据集准备适合的元数据。在 prepare_data.sh 中重复该过程(可选择性地生成 HDF5 preprocessed copy,然后计算 FID 所需的 Inception moment。

默认情况下,该训练脚本将以 Inception Score 为衡量标准选出  top 5 最优检查点并保存。对于 ImageNet 以外的数据集,模型的 Inception Score 可能不是很好的质量度量标准,因此你可以使用 which_best FID 来代替 Inception Score。

要想使用自己的训练函数(如训练 BigVAE),你可以修改  train_fns.GAN_training_function,或者将新的训练函数添加到 if config['which_train_fn'] == 'GAN' 之后(train.py 中的行)。

亮点

  • 该 repo 提供完整的训练和度量日志,以供参考。作者发现,重新实现一篇论文时最困难的事情之一是检查日志在训练早期是否排列整齐,尤其是训练需要花费数周时间时。希望这些工作有利于未来的研究。

  • 该 repo 用了加速的 FID 计算:初始 scipy 版本需要 10 多分钟来计算矩阵 sqrt,而该版本使用加速的 PyTorch 版本,能在 1 秒内完成计算。

  • 该 repo 用了一种加速型、低内存消耗的正交寄存器(ortho reg)实现。

  • 默认情况下,该 repo 只计算最大奇异值(谱范数),但该代码通过 —num_G_SVs 参数支持更多 SV 的计算。

这段代码与原始 BigGAN 的关键区别

  • 不同于BigGAN的G_lr=5e-5, D_lr=2e-5, num_D_steps=2),该repo使用出自SA-GAN (G_lr=1e-4, D_lr=4e-4, num_D_steps=1的优化器设置。虽然性能稍差,但这是该repo减少训练时间所采取的第一个措施。

  • 默认情况下,该repo不使用Cross-Replica BatchNorm(AKA Synced BatchNorm)。该repo尝试的两种变体(一种是常规简单的变体,一种是该repo中的变体)与内置BatchNorm具有略微不同的梯度(尽管采用相同的正推计算法),这对于弱化训练似乎足够了。

  • 梯度累积意味着该repo更频繁地更新SV估值和8倍BN统计。这意味着BN统计更有可能是固定统计,同时奇异值估算也更准确。基于此,默认情况下,该repo在测试模式下通过G来度量(在论文中使用BatchNorm动态统计,而不计算固定统计)。该repo依然支持固定统计(参见sample.sh脚本)。这也可能导致早期积累的梯度过时,但在实践中这不再是一个问题。

  • 当前提供的预训练模型没有通过正交规范化训练。缺少正交寄存器的训练增加了模型摆脱截断影响的概率,但看起来这一特定模型中奖了。无论如何,该repo提供两种高度优化(速度快且内存消耗最小)的正交寄存器实现,从而直接计算正交寄存器梯度。


参考文章:

工程BigGANPyTorchGAN
5
相关数据
范数技术

范数(norm),是具有“长度”概念的函数。在线性代数、泛函分析及相关的数学领域,是一个函数,其为向量空间内的所有向量赋予非零的正长度或大小。半范数反而可以为非零的向量赋予零长度。

权重技术

线性模型中特征的系数,或深度网络中的边。训练线性模型的目标是确定每个特征的理想权重。如果权重为 0,则相应的特征对模型来说没有任何贡献。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

TensorFlow技术

TensorFlow是一个开源软件库,用于各种感知和语言理解任务的机器学习。目前被50个团队用于研究和生产许多Google商业产品,如语音识别、Gmail、Google 相册和搜索,其中许多产品曾使用过其前任软件DistBelief。

张量技术

张量是一个可用来表示在一些矢量、标量和其他张量之间的线性关系的多线性函数,这些线性关系的基本例子有内积、外积、线性映射以及笛卡儿积。其坐标在 维空间内,有 个分量的一种量,其中每个分量都是坐标的函数,而在坐标变换时,这些分量也依照某些规则作线性变换。称为该张量的秩或阶(与矩阵的秩和阶均无关系)。 在数学里,张量是一种几何实体,或者说广义上的“数量”。张量概念包括标量、矢量和线性算子。张量可以用坐标系统来表达,记作标量的数组,但它是定义为“不依赖于参照系的选择的”。张量在物理和工程学中很重要。例如在扩散张量成像中,表达器官对于水的在各个方向的微分透性的张量可以用来产生大脑的扫描图。工程上最重要的例子可能就是应力张量和应变张量了,它们都是二阶张量,对于一般线性材料他们之间的关系由一个四阶弹性张量来决定。

插值技术

数学的数值分析领域中,内插或称插值(英语:interpolation)是一种通过已知的、离散的数据点,在范围内推求新数据点的过程或方法。求解科学和工程的问题时,通常有许多数据点借由采样、实验等方法获得,这些数据可能代表了有限个数值函数,其中自变量的值。而根据这些数据,我们往往希望得到一个连续的函数(也就是曲线);或者更密集的离散方程与已知数据互相吻合,这个过程叫做拟合。

规范化技术

规范化:将属性数据按比例缩放,使之落入一个小的特定区间,如-1.0 到1.0 或0.0 到1.0。 通过将属性数据按比例缩放,使之落入一个小的特定区间,如0.0到1.0,对属性规范化。对于距离度量分类算法,如涉及神经网络或诸如最临近分类和聚类的分类算法,规范化特别有用。如果使用神经网络后向传播算法进行分类挖掘,对于训练样本属性输入值规范化将有助于加快学习阶段的速度。对于基于距离的方法,规范化可以帮助防止具有较大初始值域的属性与具有较小初始值域的属相相比,权重过大。有许多数据规范化的方法,包括最小-最大规范化、z-score规范化和按小数定标规范化。

图像生成技术

图像生成(合成)是从现有数据集生成新图像的任务。

优化器技术

优化器基类提供了计算梯度loss的方法,并可以将梯度应用于变量。优化器里包含了实现了经典的优化算法,如梯度下降和Adagrad。 优化器是提供了一个可以使用各种优化算法的接口,可以让用户直接调用一些经典的优化算法,如梯度下降法等等。优化器(optimizers)类的基类。这个类定义了在训练模型的时候添加一个操作的API。用户基本上不会直接使用这个类,但是你会用到他的子类比如GradientDescentOptimizer, AdagradOptimizer, MomentumOptimizer(tensorflow下的优化器包)等等这些算法。

推荐文章
暂无评论
暂无评论~