作为「史上最强 GAN 图像生成器」,BigGAN 自去年 9 月推出以来就成为了 AI 领域最热词。其生成图像的目标和背景都高度逼真、边界自然,简直可以说是在「创造新物种」。然而 BigGAN 训练时需要的超高算力(128-512 个谷歌 TPU v3 核心)却让很多想要参与制图狂欢的开发者望而却步。
今日,BigGAN 论文的第一作者、来自英国 Heriot-Watt 大学的 Andrew Brock 发布了 BigGAN 的 PyTorch 版实现。最令人高兴的是:这一次训练模型的算力要求被降低到 4 到 8 块 GPU 了!
项目链接:https://github.com/ajbrock/BigGAN-PyTorch
该项目一出即引发了人们的广泛关注,有的人表示不敢相信,也有人哭晕在 Colab。
Brock 本次放出的 BigGAN 实现包含训练、测试、采样脚本以及完整的预训练检查点(生成器、判别器和优化器),以便你可以在自己的数据上进行微调或者从零开始训练模型。
作者表示,这些代码制作时间很长,从一开始就被设计成可操控、可扩展的基础,以方便未来的研究。作者花了很多心思考虑在什么地方具体使用什么抽象,以确保它们有效但又易于理解或改变。
这一工作是 Andrew Brock 与 MIT 的 Alex Andonian 一起完成的。
BigGAN 的 PyTorch 实现
这是由论文原作者正式发布的「非官方」BigGAN PyTorch 实现。
该 repo 包含用 4-8 个 GPU 训练 BigGAN 的代码。
如何使用
你需要用到:
1.0.1 版本的 PyTorch
tqdm、numpy、scipy 和 h5py
ImageNet 训练集
首先,你可以准备目标数据集的预处理 HDF5 版本,以便更快地输入/输出(可选)。在此之后(不管是否如此做了),你需要计算 FID 所需的 Inception moment。这些都可以通过修改并运行以下代码来完成:
sh scripts/utils/prepare_data.sh
默认情况下,假设你的 ImageNet 训练集已经下载至此目录的根文件夹 data
中,然后以 128x128 的像素分辨率准备缓存的 HDF5。
脚本文件夹中有多个 bash 脚本,此类脚本可以用不同的批量大小训练 BigGAN。这段代码假设你无法访问完整的 TPU pod,然后通过梯度累积(将多个小批量上的梯度平均化,然后仅在 N 次累积后采取优化步骤)表示相应的 mega-batches。默认情况下,launch_BigGAN_bs256x8.sh
脚本训练批量大小为 256 且具备 8 次梯度累积的完整 BigGAN 模型,其总的批量大小为 2048。在 8xV100 上进行全精度训练(无张量核),这个脚本需要 15天训练到 15 万次迭代。
你需要先确定你的设置能够支持的最大批量。这里提供的预训练模型是在 8xV100(每个有 16GB VRAM)上训练的,8xV100 能支持比默认使用的 BS256 略大的批量大小。一旦确定了这一点,你应该修改脚本,使批大小乘以梯度累积的数量等同于你期望的总批量大小(BigGAN 默认的总批量大小是 2048)。
注意,这个脚本使用参数 --load_in_mem
,该参数会将整个 I128.hdf5(约 64GB)文件加载至 RAM 中,以便更快地加载数据。如果你没有足够的 RAM 来支持这个(可能需要 96GB 以上),删除这个参数。
度量和采样
在训练过程中,该脚本将输出包含训练度量和测试度量的日志,并保存模型权重/优化器参数的多个副本(2 个最新的和 5 个得分最高的),还会在每次保存权重时产生样本和插值。日志文件夹包含处理这些日志及使用 MATLAB 绘制结果的脚本。
训练结束后,你可以使用 sample.py
生成额外的样本和插值,用不同的截断值、批大小、standing stat 累积次数等进行测试。示例参考 sample_BigGAN_bs256x8.sh
脚本。
默认情况下,所有内容都会保存至 weights/samples/logs/data 文件夹中,这些文件夹应与该 repo 在同一文件夹中。你可以使用 --base_root
参数将这些文件夹指向不同的根目录,或者使用对应的参数(如 --logs_root)为
每个文件夹选择特定的位置。
该 repo 还包含运行 BigGAN-deep 的脚本,但作者尚未使用它们来完整地训练模型,所以可将其视为未经测试。另外,该 repo 包含在 CIFAR 上运行模型的脚本,以及在 ImageNet 上运行 SA-GAN(带有EMA)和 SN-GAN 的脚本。SA-GAN 代码假设你有 4xTitanX(或具备同等 RAM 的 GPU),并使用 128 的批量大小和 2 个梯度累积来训练。
关于 Inception 度量的重要提示
该 repo 使用 PyTorch 内置 inception 网络来计算 IS 和 FID 分数。这些分数与使用官方 TF inception 代码得到的不同,且仅用于监控目的。使用 --sample_npz
参数在模型上运行 sample.py,然后运行 inception_tf13 来计算真实的 TensorFlow IS。注意:你需要安装 TensorFlow 1.3 或更早版本,因为 TF1.4+ 会破坏原始 IS 代码。
预训练模型
该 repo 包含两个预训练模型检查点(具备 G、D、G 的 EMA copy、优化器和 state dict):
主要检查点是在 128x128 ImageNet 图像上训练的 BigGAN,该模型使用 BS256 和 8 次梯度累积,并在崩溃前实现,其 TF Inception Score 为 97.35 +/- 1.79,详见:https://drive.google.com/open?id=1nAle7FCVFZdix2—ks0r5JBkFnKw8ctW。
第一个模型的更早检查点 (100k G iters)性能优秀且在崩溃前实现,可能比较容易微调,详见:https://drive.google.com/open?id=1dmZrcVJUAWkPBGza_XgswSuT-UODXZcO。
使用 Places-365 数据集预训练模型也将很快开源。
该 repo 还包含将原始 TFHub BigGAN Generator 权重迁移到 PyTorch 的脚本。详见 TFHub 文件夹。
使用自己的数据集或新的训练函数对模型进行微调
如果你想继续被中断的训练或者微调预训练模型,运行同样的启动脚本,不过这次需要添加 —resume 参数。实验名称是从配置中自动生成的,但是你可以使用 —experiment_name 参数对其进行重写(例如你想使用修改后的优化器设置来微调模型)。
要想使用自己的数据集,你需要将其添加到 datasets.py,并修改 utils.py 中的 convenience dicts (dset_dict, imsize_dict, root_dict, nclass_dict, classes_per_sheet_dict),从而为自己的数据集准备适合的元数据。在 prepare_data.sh 中重复该过程(可选择性地生成 HDF5 preprocessed copy,然后计算 FID 所需的 Inception moment。
默认情况下,该训练脚本将以 Inception Score 为衡量标准选出 top 5 最优检查点并保存。对于 ImageNet 以外的数据集,模型的 Inception Score 可能不是很好的质量度量标准,因此你可以使用 which_best FID 来代替 Inception Score。
要想使用自己的训练函数(如训练 BigVAE),你可以修改 train_fns.GAN_training_function,或者将新的训练函数添加到 if config['which_train_fn'] == 'GAN' 之后(
train.py 中的行)。
亮点
该 repo 提供完整的训练和度量日志,以供参考。作者发现,重新实现一篇论文时最困难的事情之一是检查日志在训练早期是否排列整齐,尤其是训练需要花费数周时间时。希望这些工作有利于未来的研究。
该 repo 用了加速的 FID 计算:初始 scipy 版本需要 10 多分钟来计算矩阵 sqrt,而该版本使用加速的 PyTorch 版本,能在 1 秒内完成计算。
该 repo 用了一种加速型、低内存消耗的正交寄存器(ortho reg)实现。
默认情况下,该 repo 只计算最大奇异值(谱范数),但该代码通过 —num_G_SVs 参数支持更多 SV 的计算。
这段代码与原始 BigGAN 的关键区别
不同于BigGAN的G_lr=5e-5, D_lr=2e-5, num_D_steps=2),该repo使用出自SA-GAN (G_lr=1e-4, D_lr=4e-4, num_D_steps=1的优化器设置。虽然性能稍差,但这是该repo减少训练时间所采取的第一个措施。
默认情况下,该repo不使用Cross-Replica BatchNorm(AKA Synced BatchNorm)。该repo尝试的两种变体(一种是常规简单的变体,一种是该repo中的变体)与内置BatchNorm具有略微不同的梯度(尽管采用相同的正推计算法),这对于弱化训练似乎足够了。
梯度累积意味着该repo更频繁地更新SV估值和8倍BN统计。这意味着BN统计更有可能是固定统计,同时奇异值估算也更准确。基于此,默认情况下,该repo在测试模式下通过G来度量(在论文中使用BatchNorm动态统计,而不计算固定统计)。该repo依然支持固定统计(参见sample.sh脚本)。这也可能导致早期积累的梯度过时,但在实践中这不再是一个问题。
当前提供的预训练模型没有通过正交规范化训练。缺少正交寄存器的训练增加了模型摆脱截断影响的概率,但看起来这一特定模型中奖了。无论如何,该repo提供两种高度优化(速度快且内存消耗最小)的正交寄存器实现,从而直接计算正交寄存器梯度。
参考文章: