Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

机器之心编辑部编译

不用卷积,也能生成清晰图像,华人博士生首次尝试用两个Transformer构建一个GAN

「attention is really becoming『all you need』.」

最近,CV 研究者对 transformer 产生了极大的兴趣并取得了不少突破。这表明,transformer 有可能成为计算机视觉任务(如分类、检测和分割)的强大通用模型。

我们都很好奇:在计算机视觉领域,transformer 还能走多远?对于更加困难的视觉任务,比如生成对抗网络 (GAN),transformer 表现又如何?

在这种好奇心的驱使下,德州大学奥斯汀分校的 Yifan Jiang、Zhangyang Wang,IBM Research 的 Shiyu Chang 等研究者进行了第一次试验性研究,构建了一个只使用纯 transformer 架构、完全没有卷积的 GAN,并将其命名为 TransGAN。与其它基于 transformer 的视觉模型相比,仅使用 transformer 构建 GAN 似乎更具挑战性,这是因为与分类等任务相比,真实图像生成的门槛更高,而且 GAN 训练本身具有较高的不稳定性。

  • 论文链接:https://arxiv.org/pdf/2102.07074.pdf

  • 代码链接:https://github.com/VITA-Group/TransGAN

从结构上来看,TransGAN 包括两个部分:一个是内存友好的基于 transformer 的生成器,该生成器可以逐步提高特征分辨率,同时降低嵌入维数;另一个是基于 transformer 的 patch 级判别器。

研究者还发现,TransGAN 显著受益于数据增强(超过标准的 GAN)、生成器的多任务协同训练策略和强调自然图像邻域平滑的局部初始化自注意力。这些发现表明,TransGAN 可以有效地扩展至更大的模型和具有更高分辨率的图像数据集。

实验结果表明,与当前基于卷积骨干的 SOTA GAN 相比,表现最佳的 TransGAN 实现了极具竞争力的性能。具体来说,TransGAN 在 STL-10 上的 IS 评分为 10.10,FID 为 25.32,实现了新的 SOTA。

该研究表明,对于卷积骨干以及许多专用模块的依赖可能不是 GAN 所必需的,纯 transformer 有足够的能力生成图像。

在该论文的相关讨论中,有读者调侃道,「attention is really becoming『all you need』.」

不过,也有部分研究者表达了自己的担忧:在 transformer 席卷整个社区的大背景下,势单力薄的小实验室要怎么活下去?

如果 transformer 真的成为社区「刚需」,如何提升这类架构的计算效率将成为一个棘手的研究问题。

基于纯 Transformer 的 GAN

作为基础块的 Transformer 编码器

研究者选择将 Transformer 编码器(Vaswani 等人,2017)作为基础块,并尽量进行最小程度的改变。编码器由两个部件组成,第一个部件由一个多头自注意力模块构造而成,第二个部件是具有 GELU 非线性的前馈 MLP(multiple-layer perceptron,多层感知器)。此外,研究者在两个部件之前均应用了层归一化(Ba 等人,2016)。两个部件也都使用了残差连接。

内存友好的生成器

NLP 中的 Transformer 将每个词作为输入(Devlin 等人,2018)。但是,如果以类似的方法通过堆叠 Transformer 编码器来逐像素地生成图像,则低分辨率图像(如 32×32)也可能导致长序列(1024)以及更高昂的自注意力开销。

所以,为了避免过高的开销,研究者受到了基于 CNN 的 GAN 中常见设计理念的启发,在多个阶段迭代地提升分辨率(Denton 等人,2015;Karras 等人,2017)。他们的策略是逐步增加输入序列,并降低嵌入维数

如下图 1 左所示,研究者提出了包含多个阶段的内存友好、基于 Transformer 的生成器:

每个阶段堆叠了数个编码器块(默认为 5、2 和 2)。通过分段式设计,研究者逐步增加特征图分辨率,直到其达到目标分辨率 H_T×W_T。具体来说,该生成器以随机噪声作为其输入,并通过一个 MLP 将随机噪声传递给长度为 H×W×C 的向量。该向量又变形为分辨率为 H×W 的特征图(默认 H=W=8),每个点都是 C 维嵌入。然后,该特征图被视为长度为 64 的 C 维 token 序列,并与可学得的位置编码相结合。

与 BERT(Devlin 等人,2018)类似,该研究提出的 Transformer 编码器以嵌入 token 作为输入,并递归地计算每个 token 之间的匹配。为了合成分辨率更高的图像,研究者在每个阶段之后插入了一个由 reshaping 和 pixelshuffle 模块组成的上采样模块。

具体操作上,上采样模块首先将 1D 序列的 token 嵌入变形为 2D 特征图,然后采用 pixelshuffle 模块对 2D 特征图的分辨率进行上采样处理,并下采样嵌入维数,最终得到输出。然后,2D 特征图 X’_0 再次变形为嵌入 token 的 1D 序列,其中 token 数为 4HW,嵌入维数为 C/4。所以,在每个阶段,分辨率(H, W)提升到两倍,同时嵌入维数 C 减少至输入的四分之一。这一权衡(trade-off)策略缓和了内存和计算量需求的激增。
研究者在多个阶段重复上述流程,直到分辨率达到(H_T , W_T )。然后,他们将嵌入维数投影到 3,并得到 RGB 图像

用于判别器的 tokenized 输入

与那些需要准确合成每个像素的生成器不同,该研究提出的判别器只需要分辨真假图像即可。这使得研究者可以在语义上将输入图像 tokenize 为更粗糙的 patch level(Dosovitskiy 等人,2020)。

如上图 1 右所示,判别器以图像的 patch 作为输入。研究者将输入图像分解为 8 × 8 个 patch,其中每个 patch 可被视为一个「词」。然后,8 × 8 个 patch 通过一个线性 flatten 层转化为 token 嵌入的 1D 序列,其中 token 数 N = 8 × 8 = 64,嵌入维数为 C。再之后,研究者在 1D 序列的开头添加了可学得位置编码和一个 [cls] token。在通过 Transformer 编码器后,分类 head 只使用 [cls] token 来输出真假预测。
实验

CIFAR-10 上的结果

研究者在 CIFAR-10 数据集上对比了 TransGAN 和近来基于卷积的 GAN 的研究,结果如下表 5 所示:

如上表 5 所示,TransGAN 优于 AutoGAN (Gong 等人,2019) ,在 IS 评分方面也优于许多竞争者,如 SN-GAN (Miyato 等人, 2018)、improving MMDGAN (Wang 等人,2018a)、MGAN (Hoang 等人,2018)。TransGAN 仅次于 Progressive GAN 和 StyleGAN v2。

对比 FID 结果,研究发现,TransGAN 甚至优于 Progressive GAN,而略低于 StyleGANv2 (Karras 等人,2020b)。在 CIFAR-10 上生成的可视化示例如下图 4 所示:

STL-10 上的结果

研究者将 TransGAN 应用于另一个流行的 48×48 分辨率的基准 STL-10。为了适应目标分辨率,该研究将第一阶段的输入特征图从(8×8)=64 增加到(12×12)=144,然后将提出的 TransGAN-XL 与自动搜索的 ConvNets 和手工制作的 ConvNets 进行了比较,结果下表 6 所示:

与 CIFAR-10 上的结果不同,该研究发现,TransGAN 优于所有当前的模型,并在 IS 和 FID 得分方面达到新的 SOTA 性能。

高分辨率生成

由于 TransGAN 在标准基准 CIFAR-10 和 STL-10 上取得不错的性能,研究者将 TransGAN 用于更具挑战性的数据集 CelebA 64 × 64。
TransGAN-XL 的 FID 评分为 12.23,这表明 TransGAN-XL 可适用于高分辨率任务。可视化结果如图 4 所示。

局限性

虽然 TransGAN 已经取得了不错的成绩,但与最好的手工设计的 GAN 相比,它还有很大的改进空间。在论文的最后,作者指出了以下几个具体的改进方向:
  • 对 G 和 D 进行更加复杂的 tokenize 操作,如利用一些语义分组 (Wu et al., 2020)。

  • 使用代理任务(pretext task)预训练 Transformer,这样可能会改进该研究中现有的 MT-CT。

  • 更加强大的注意力形式,如 (Zhu 等人,2020)。

  • 更有效的自注意力形式 (Wang 等人,2020;Choromanski 等人,2020),这不仅有助于提升模型效率,还能节省内存开销,从而有助于生成分辨率更高的图像。

作者简介
本文一作 Yifan Jiang 是德州大学奥斯汀分校电子与计算机工程系的一年级博士生(此前在德克萨斯 A&M 大学学习过一年),本科毕业于华中科技大学,研究兴趣集中在计算机视觉深度学习等方向。目前,Yifan Jiang 主要从事神经架构搜索、视频理解和高级表征学习领域的研究,师从德州大学奥斯汀分校电子与计算机工程系助理教授 Zhangyang Wang。

在本科期间,Yifan Jiang 曾在字节跳动 AI Lab 实习。今年夏天,他将进入 Google Research 实习。

一作主页:https://yifanjiang.net/

参考链接:https://www.reddit.com/r/MachineLearning/comments/ll30kf/r_transgan_two_transformers_can_make_one_strong/
理论不用卷积纯TransformerGAN
1
相关数据
字节跳动机构

北京字节跳动科技有限公司成立于2012年,是最早将人工智能应用于移动互联网场景的科技企业之一,是中国北京的一家信息科技公司,地址位于北京市海淀区知春路甲48号。其独立研发的“今日头条”客户端,通过海量信息采集、深度数据挖掘和用户行为分析,为用户智能推荐个性化信息,从而开创了一种全新的新闻阅读模式

https://bytedance.com
IBM机构

是美国一家跨国科技公司及咨询公司,总部位于纽约州阿蒙克市。IBM主要客户是政府和企业。IBM生产并销售计算机硬件及软件,并且为系统架构和网络托管提供咨询服务。截止2013年,IBM已在全球拥有12个研究实验室和大量的软件开发基地。IBM虽然是一家商业公司,但在材料、化学、物理等科学领域却也有很高的成就,利用这些学术研究为基础,发明很多产品。比较有名的IBM发明的产品包括硬盘、自动柜员机、通用产品代码、SQL、关系数据库管理系统、DRAM及沃森。

https://www.ibm.com/us-en/
相关技术
深度学习技术

深度学习(deep learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法。 深度学习是机器学习中一种基于对数据进行表征学习的算法,至今已有数种深度学习框架,如卷积神经网络和深度置信网络和递归神经网络等已被应用在计算机视觉、语音识别、自然语言处理、音频识别与生物信息学等领域并获取了极好的效果。

自注意力技术

自注意力(Self-attention),有时也称为内部注意力,它是一种涉及单序列不同位置的注意力机制,并能计算序列的表征。自注意力在多种任务中都有非常成功的应用,例如阅读理解、摘要概括、文字蕴含和语句表征等。自注意力这种在序列内部执行 Attention 的方法可以视为搜索序列内部的隐藏关系,这种内部关系对于翻译以及序列任务的性能非常重要。

基准技术

一种简单的模型或启发法,用作比较模型效果时的参考点。基准有助于模型开发者针对特定问题量化最低预期效果。

表征学习技术

在机器学习领域,表征学习(或特征学习)是一种将原始数据转换成为能够被机器学习有效开发的一种技术的集合。在特征学习算法出现之前,机器学习研究人员需要利用手动特征工程(manual feature learning)等技术从原始数据的领域知识(domain knowledge)建立特征,然后再部署相关的机器学习算法。虽然手动特征工程对于应用机器学习很有效,但它同时也是很困难、很昂贵、很耗时、并依赖于强大专业知识。特征学习弥补了这一点,它使得机器不仅能学习到数据的特征,并能利用这些特征来完成一个具体的任务。

计算机视觉技术

计算机视觉(CV)是指机器感知环境的能力。这一技术类别中的经典任务有图像形成、图像处理、图像提取和图像的三维推理。目标识别和面部识别也是很重要的研究领域。

图像生成技术

图像生成(合成)是从现有数据集生成新图像的任务。

上采样技术

在数字信号处理中,上采样、扩展和内插是与多速率数字信号处理系统中的重采样过程相关的术语。 上采样可以与扩展同义,也可以描述整个扩展和过滤(插值)过程。

生成对抗网络技术

生成对抗网络是一种无监督学习方法,是一种通过用对抗网络来训练生成模型的架构。它由两个网络组成:用来拟合数据分布的生成网络G,和用来判断输入是否“真实”的判别网络D。在训练过程中,生成网络-G通过接受一个随机的噪声来尽量模仿训练集中的真实图片去“欺骗”D,而D则尽可能的分辨真实数据和生成网络的输出,从而形成两个网络的博弈过程。理想的情况下,博弈的结果会得到一个可以“以假乱真”的生成模型。

层归一化技术

深度神经网络的训练是具有高度的计算复杂性的。减少训练的时间成本的一种方法是对神经元的输入进行规范化处理进而加快网络的收敛速度。层规范化是在训练时和测试时对数据同时进行处理,通过对输入同一层的数据进行汇总,计算平均值和方差,来对每一层的输入数据做规范化处理。层规范化是基于批规范化进行优化得到的。相比较而言,批规范化是对一个神经元输入的数据以mini-batch为单位来进行汇总,计算平均值和方法,再用这个数据对每个训练样例的输入进行规整。层规范化在面对RNN等问题的时候效果更加优越,也不会受到mini-batch选值的影响。

堆叠技术

堆叠泛化是一种用于最小化一个或多个泛化器的泛化误差率的方法。它通过推导泛化器相对于所提供的学习集的偏差来发挥其作用。这个推导的过程包括:在第二层中将第一层的原始泛化器对部分学习集的猜测进行泛化,以及尝试对学习集的剩余部分进行猜测,并且输出正确的结果。当与多个泛化器一起使用时,堆叠泛化可以被看作是一个交叉验证的复杂版本,利用比交叉验证更为复杂的策略来组合各个泛化器。当与单个泛化器一起使用时,堆叠泛化是一种用于估计(然后纠正)泛化器的错误的方法,该泛化器已经在特定学习集上进行了训练并被询问了特定问题。

生成对抗技术

生成对抗是训练生成对抗网络时,两个神经网络相互博弈的过程。两个网络相互对抗、不断调整参数,最终目的是使判别网络无法判断生成网络的输出结果是否真实。

推荐文章
暂无评论
暂无评论~