Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

Goodfellow点赞的相对鉴别器:标准GAN中缺失的关键因素

Ian Goodfellow 刚刚评论了一篇 GAN 论文,他认为这一篇关于相对 GAN 的论文有非常好的属性。Goodfellow 在小型数据集上尝试了这种相对 GAN,并有很好的效果。这种相对 GAN 基于非常朴素的概念:在训练中 GAN 应该同时降低真实数据看起来为真的概率。为此该论文提出了相对鉴别器,并在给定真实数据下估计它们比随机采样的假数据要真实的概率。

生成对抗网络(GAN)[Hong et al., 2017] 是生成模型的一大类别,两个竞争的神经网络——鉴别器 D 和生成器 G 在其中玩游戏。训练 D 用于分辨数据的真假,而 G 用于生成可以被 D 误识别为真数据的假数据。在 Goodfellow 等 [2014] 提出的原始 GAN(我们称之为标准 GAN,即 SGAN)中,D 是分类器,用于预测输入数据为真的概率。如果 D 达到最佳状态,SGAN 的损失函数就会近似于 JS 散度(Jensen–Shannon divergence,JSD)[Goodfellow et al., 2014]。

SGAN 有两种生成损失函数变体:饱和的和非饱和的。实践证明,前者非常不稳定,而后者则稳定得多 [Goodfellow et al., 2014]。Arjovsky 和 Bottou[2017] 证明,在某些条件下,如果能够将真假数据完美地分类,饱和损失函数的梯度为0,而非饱和损失函数的梯度不为 0,且不稳定。在实践中,这意味着 SGAN 中的鉴别器通常训练效果不佳;否则梯度就会消失,训练也随之停止。这一问题在高维设定中会更加明显(如高分辨率图像及具有较高表达能力的鉴别器架构),因为在这种设定下,实现训练集完美分类的自由度更高。

为了提升 SGAN,许多 GAN 变体可以选择使用不同的损失函数及非分类器的鉴别器(如 LSGAN[Mao et al., 2017]、WGAN [Arjovsky et al., 2017])。尽管这些方法适当提升了稳定性和数据质量,但 Lucic 等人做的大型研究 [2017] 表明,这些方法在 SGAN 上并没有持续改进。此外,一些非常成功的的方法(如 WGAN-GP [Gulrajani et al., 2017])对计算的要求比 SGAN 高得多。

最近许多成功的 GAN 都是基于积分概率度量(Integral Probability Metric,IPM)[Müller, 1997](如 WGAN [Arjovsky et al., 2017]、WGAN-GP[Gulrajani et al., 2017]、Sobolev GAN [Mroueh et al., 2017]、Fisher GAN [Mroueh and Sercu, 2017])。在基于 IPM 的 GAN 中,鉴别器是实值的,并被限制在一类特定的函数中,以免增长过快;这是一种正则化形式,防止 D 变得过强(即大致将真假数据完美分类)。在实践中,我们发现基于 IPM 的 GAN 鉴别器可以经过多次迭代训练而不造成梯度消失。

IPM 限制已被证明在不基于 IPM 的 GAN 中同样有益。WGAN 限制(即 Lipschitz 鉴别器)已通过谱归一化被证明在其他 GAN 中也有帮助 [Miyato et al., 2018]。WGAN-GP 限制(即真假数据梯度范数等于 1 的鉴别器)被证明在 SGAN 中有益 [Fedus et al., 2017](以及 Kodali 等人非常相似的梯度罚分 [ 2017 ])。

尽管这表明某些 IPM 限制会提高 GAN 的稳定性,但这并不能解释为什么 IPM 所提供的稳定性通常比 GAN 中的其他度量/散度提供的更高(如 SGAN 的 JSD、f-GAN 的 f-divergences[Nowozin et al., 2016])。本文认为,不基于 IPM 的 GAN 缺失一个关键元素——一个相对鉴别器,而基于 IPM 的 GAN 则拥有该辨别器。研究表明,为了使 GAN 接近散度最小化,并根据小批量样本中有一半为假这一先验知识产生合理的预测,相对鉴别器是必要的。论文提供的经验证据表明,带有相对鉴别器的 GAN 更稳定,产生的数据质量也更高。

论文:The relativistic discriminator: a key element missing from standard GAN

论文地址:https://arxiv.org/abs/1807.00734

在标准生成对抗网络(SGAN)中,鉴别器 D 用于估计输入数据为真实样本的概率,而生成器 G 用于提高数据以假乱真的概率。我们认为它应该同时降低真实数据看起来为真的概率,因为 1)这可以解释批量数据中一半为假的先验知识,2)我们可以在最小化散度的过程中观察到这种现象,3)在最优设定中,SGAN 等价于积分概率度量(IPM)GAN。我们证明该属性可以通过使用一个「相对鉴别器」(Relativistic Discriminator)导出,该鉴别器在给定真实数据下估计它们比随机采样的假数据要真实的概率。

我们还提出了一种变体,其中鉴别器估计平均给定的真实数据要比假数据更加真实的概率。我们泛化两种方法到非标准 GAN 损失函数中,并分别称之为相对 GAN(RGAN)和相对平均 GAN(RaGAN)。我们的研究表明,基于 IPM 的 GAN 是使用恒等函数的 RGAN 的子集。实验中,我们观察到 1)与非相对 GAN 相比,RGAN 和 RaGAN 生成的数据样本更稳定且质量更高。2)与 WGAN-GP 相比,带有梯度惩罚的标准 RaGAN 生成的数据质量更高,同时每个生成器的更新还只要求单个鉴别器更新,这将达到当前最优性能的时间降低到原来的 1/4。3)RaGAN 能从非常小的样本(N=2011)生成高分别率的图像(256×256),而 GAN 与 LSGAN 都不能。此外,这些图像也显著优于 WGAN-GP 和带谱归一化的 SGAN 所生成的图像。

4 方法

4.2 相对 GAN

更一般的,我们考虑了由 a(C(x_r)−C(x_f )) 定义的任意鉴别器,其中 a 为激活函数,它因为输入 C(x_r)−C(x_f ) 而变得具有相对性。这意味着基本上任意 GAN 都可以添加一个相对鉴别器。这能组成新一类的模型,我们称之为相对 GAN(Relativistic GAN/RGAN)。

大多数 GAN 可以在 critic 方面做非常普遍的参数化:

其中 f_1、f_2、g_1、g_2 都是标量到标量的函数。如果我们使用一个相对鉴别器,那么 GAN 现在就可以表示为以下形式:

基于 IPM 的 GAN 代表了 RGAN 的特例,其中 f_1(y) = g_2(y) = −y、f_2(y) = g_1(y) = y。重要的是,g_1 一般在 GAN 中是忽略的,因为它的梯度为 0,且生成器并不能影响它。然而在 RGAN 中,g_1 受到了假数据的影响,所以受到了生成器的影响。因此 g_1 一般有非零的梯度且需要在生成器损失中指定。这意味着在大多数 RGAN(除了基于 IPM 的 GAN,因为它们使用恒等函数)中,我们需要训练生成器以最小化预期的总体损失函数,而不仅仅只是它的一半。

算法 1 展示了训练 RGAN 的过程:

5 实验

表 1:传统定义的 GAN 鉴别器(P(x_r is real) = sigmoid(C(x_r)))与相对平均鉴别器(P(x_r is real|C(x_f )) = sigmoid(C(x_r) − C(x_f )))的输出样本。其中面包表示真实图像、小狗表示伪造图像。

表 3:在 CIFAR-10 数据集上执行 100k 次生成器迭代所得出的 Fréchet Inception 距离(FID),它使用不同 GAN 损失函数的不稳定的配置。

表 4:在 CAT 数据集和不同的 GAN 损失函数上执行 20k、30k 到 100k 生成器迭代后的 Fréchet Inception 距离(FID),其中 min、max、mean 和 SD 分别表示 FID 的最大、最小、平均、标准差值。

在 Ian Goodfellow 对该论文的评论中,他非常关注附录所展示出来的生成器训练速度。在一般的 GAN 训练中,我们通常会发现生成器在初始化后训练地非常慢,它要经过很多次迭代才开始不再生成噪声。而在这一篇论文中,作者表示 GAN 和 LSGAN 在 CAT 数据集上迭代 5000 次仍然只能生成如下所示 256×256 的噪声

而 RaSGAN 在初始化后就能快速学习生成图像。


理论生成对抗网络生成模型损失函数
4
相关数据
范数技术

范数(norm),是具有“长度”概念的函数。在线性代数、泛函分析及相关的数学领域,是一个函数,其为向量空间内的所有向量赋予非零的正长度或大小。半范数反而可以为非零的向量赋予零长度。

激活函数技术

在 计算网络中, 一个节点的激活函数定义了该节点在给定的输入或输入的集合下的输出。标准的计算机芯片电路可以看作是根据输入得到"开"(1)或"关"(0)输出的数字网络激活函数。这与神经网络中的线性感知机的行为类似。 一种函数(例如 ReLU 或 S 型函数),用于对上一层的所有输入求加权和,然后生成一个输出值(通常为非线性值),并将其传递给下一层。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

损失函数技术

在数学优化,统计学,计量经济学,决策理论,机器学习和计算神经科学等领域,损失函数或成本函数是将一或多个变量的一个事件或值映射为可以直观地表示某种与之相关“成本”的实数的函数。

神经网络技术

(人工)神经网络是一种起源于 20 世纪 50 年代的监督式机器学习模型,那时候研究者构想了「感知器(perceptron)」的想法。这一领域的研究者通常被称为「联结主义者(Connectionist)」,因为这种模型模拟了人脑的功能。神经网络模型通常是通过反向传播算法应用梯度下降训练的。目前神经网络有两大主要类型,它们都是前馈神经网络:卷积神经网络(CNN)和循环神经网络(RNN),其中 RNN 又包含长短期记忆(LSTM)、门控循环单元(GRU)等等。深度学习是一种主要应用于神经网络帮助其取得更好结果的技术。尽管神经网络主要用于监督学习,但也有一些为无监督学习设计的变体,比如自动编码器和生成对抗网络(GAN)。

先验知识技术

先验(apriori ;也译作 先天)在拉丁文中指“来自先前的东西”,或稍稍引申指“在经验之前”。近代西方传统中,认为先验指无需经验或先于经验获得的知识。先验知识不依赖于经验,比如,数学式子2+2=4;恒真命题“所有的单身汉一定没有结婚”;以及来自纯粹理性的推断“本体论证明”

规范化技术

规范化:将属性数据按比例缩放,使之落入一个小的特定区间,如-1.0 到1.0 或0.0 到1.0。 通过将属性数据按比例缩放,使之落入一个小的特定区间,如0.0到1.0,对属性规范化。对于距离度量分类算法,如涉及神经网络或诸如最临近分类和聚类的分类算法,规范化特别有用。如果使用神经网络后向传播算法进行分类挖掘,对于训练样本属性输入值规范化将有助于加快学习阶段的速度。对于基于距离的方法,规范化可以帮助防止具有较大初始值域的属性与具有较小初始值域的属相相比,权重过大。有许多数据规范化的方法,包括最小-最大规范化、z-score规范化和按小数定标规范化。

生成模型技术

在概率统计理论中, 生成模型是指能够随机生成观测数据的模型,尤其是在给定某些隐含参数的条件下。 它给观测值和标注数据序列指定一个联合概率分布。 在机器学习中,生成模型可以用来直接对数据建模(例如根据某个变量的概率密度函数进行数据采样),也可以用来建立变量间的条件概率分布。

生成对抗网络技术

生成对抗网络是一种无监督学习方法,是一种通过用对抗网络来训练生成模型的架构。它由两个网络组成:用来拟合数据分布的生成网络G,和用来判断输入是否“真实”的判别网络D。在训练过程中,生成网络-G通过接受一个随机的噪声来尽量模仿训练集中的真实图片去“欺骗”D,而D则尽可能的分辨真实数据和生成网络的输出,从而形成两个网络的博弈过程。理想的情况下,博弈的结果会得到一个可以“以假乱真”的生成模型。

WGAN技术

就其本质而言,任何生成模型的目标都是让模型(习得地)的分布与真实数据之间的差异达到最小。然而,传统 GAN 中的判别器 D 并不会当模型与真实的分布重叠度不够时去提供足够的信息来估计这个差异度——这导致生成器得不到一个强有力的反馈信息(特别是在训练之初),此外生成器的稳定性也普遍不足。 Wasserstein GAN 在原来的基础之上添加了一些新的方法,让判别器 D 去拟合模型与真实分布之间的 Wasserstein 距离。Wassersterin 距离会大致估计出「调整一个分布去匹配另一个分布还需要多少工作」。此外,其定义的方式十分值得注意,它甚至可以适用于非重叠的分布。

推荐文章
暂无评论
暂无评论~