Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

周志华等提出 RNN 可解释性方法,看看 RNN 内部都干了些什么

除了数值计算,你真的知道神经网络内部在做什么吗?我们一直理解深度模型都靠里面的运算流,但对于是不是具有物理意义、语义意义都还是懵懵懂懂。尤其是在循环神经网络中,我们只知道每一个时间步它都在利用以前的记忆抽取当前语义信息,但具体到怎么以及什么的时候,我们就无能为力了。在本文中,南京大学的周志华等研究者尝试利用有限状态机探索 RNN 的内在机制,这种具有物理意义的模型可以将 RNN 的内部流程展现出来,并帮助我们窥探 RNN 到底都干了些什么。

结构化学习(Structure learning)的主要任务是处理结构化的输出,它不像分类问题那样为每个独立的样本预测一个值。这里所说的结构可以是图、序列、树形结构和向量等。一般用于结构化输出的机器学习算法有各种概率图模型感知机和 SVM 等。在过去的数十年里,结构化学习已经广泛应用于目标追踪、目标定位和语义解析等任务,而多标签学习和聚类等很多问题同样与结构化学习有很强的关联。

一般来说,结构化学习会使用结构化标注作为监督信息,并借助相应的算法来预测这些结构化信息而实现优良的性能。然而,随着机器学习算法变得越来越复杂,它们的可解释性则变得越来越重要,这里的可解释性指的是如何理解学习过程的内在机制或内部流程。在这篇论文中,周志华等研究者重点关注深度学习模型,并探索如何学习这些模型的结构以提升模型可解释性。

探索深度学习模型的可解释性通常都比较困难,然而对于 RNN 等特定类型的深度学习模型,我们还是有方法解决的。循环神经网络(RNN)作为深度神经网络中的主要组成部分,它们在各种序列数据任务中有非常广泛的应用,特别是那些带有门控机制的变体,例如带有一个门控的 MGU、带有两个门控的 GRU 和三个门控的 LSTM

除了我们熟悉的 RNN 以外,还有另一种工具也能捕捉序列数据,即有限状态机(Finite State Automaton/FSA)。FSA 由有限状态和状态之间的转换组成,它将从一个状态转换为另一个状态以响应外部序列输入。FSA 的转换过程有点类似于 RNN,因为它们都是一个一个接收序列中的输入元素,并在相应的状态间传递。与 RNN 不同的是,FSA 的内部机制更容易被解释,因为我们更容易模拟它的过程。此外,FSA 在状态间的转换具有物理意义,而 RNN 只有数值计算的意义。

FSA 的这些特性令周志华团队探索从 RNN 中学习一个 FSA 模型,并利用 FSA 的天然可解释性能力来理解 RNN 的内部机制,因此周志华等研究者采用 FSA 作为他们寻求的可解释结构。此外,这一项研究与之前关于结构化学习的探索不同。之前的方法主要关注结构化的预测或分类结果,这一篇文章主要关注中间隐藏层的输出结构,这样才能更好地理解 RNN 的内在机制。

为了从 RNN 中学习 FSA,并使用 FSA 解释 RNN 的内在机制,我们需要知道如何学习 FSA 以及具体解释 RNN 中的什么。对于如何学习 FSA,研究者发现非门控的经典 RNN 隐藏状态倾向于构造一些集群。但是仍然存在一些重要的未解决问题,其中之一是我们不知道构造集群的倾向在门控 RNN 中是否也存在。我们同样需要考虑效率问题,因为门控 RNN 变体通常用于大型数据集中。对于具体解释 RNN 中的什么,研究者分析了门控机制在 LSTM、GRU 和 MGU 等模型中的作用,特别是不同门控 RNN 中门的数量及其影响。鉴于 FSA 中状态之间的转换是有物理意义的,因此我们可以从与 RNN 对应的 FSA 转换推断出语义意义。

在这篇论文中,周志华等研究者尝试从 RNN 学习 FSA,他们首先验证了除不带门控的经典 RNN 外,其它门控 RNN 变体的隐藏状态同样也具有天然的集群属性。然后他们提出了两种方法,其一是高效的聚类方法 k-means++。另外一种方法根据若相同序列中隐藏状态相近,在几何空间内也相近的现象而提出,这一方法被命名为 k-means-x。随后研究者通过设计五个必要的元素来学习 FSA,即字母表、一组状态、初始状态、一组接受状态和状态转换,他们最后将这些方法应用到了模拟数据和真实数据中。

对于人工模拟数据,研究者首先表示我们可以理解在运行过程学习到的 FSA。然后他们展示了准确率和集群数量之间的关系,并表示门控机制对于门控 RNN 是必要的,并且门越少越好。这在一定程度上解释了为什么只有一个门控的 MGU 在某种程度上优于其它门控 RNN。

对于情感分析这一真实数据,研究者发现在数值计算的背后,RNN 的隐藏状态确实具有区分语义差异性的能力。因为在对应的 FSA 中,导致正类 / 负类输出的词确实在做一些正面或负面的人类情感理解。

论文:Learning with Interpretable Structure from RNN

论文地址:https://arxiv.org/pdf/1810.10708.pdf

摘要:在结构化学习中,输出通常是一个结构,可以作为监督信息用于获取良好的性能。考虑到深度学习可解释性在近年来受到了越来越多的关注,如果我们能重深度学习模型中学到可解释的结构,将是很有帮助的。在本文中,我们聚焦于循环神经网络(RNN),它的内部机制目前仍然是没有得到清晰的理解。我们发现处理序列数据的有限状态机(FSA)有更加可解释的内部机制,并且可以从 RNN 学习出来作为可解释结构。我们提出了两种不同的聚类方法来从 RNN 学习 FSA。我们首先给出 FSA 的图形,以展示它的可解释性。从 FSA 的角度,我们分析了 RNN 的性能如何受到门控数量的影响,以及数值隐藏状态转换背后的语义含义。我们的结果表明有简单门控结构的 RNN 例如最小门控单元(MGU)的表现更好,并且 FSA 中的转换可以得到和对应单词相关的特定分类结果,该过程对于人类而言是可理解的。

本文的方法

在这一部分,我们介绍提出方法的直觉来源和方法框架。我们将 RNN 的隐藏状态表示为一个向量或一个点。因此当多个序列被输入到 RNN 时,会积累大量的隐藏状态点,并且它们倾向于构成集群。为了验证该结论,我们展示了在 MGU、SRU、GRU 和 LSTM 上的隐藏状态点的分布,如图 1(a)到(d)所示。

图 1:隐藏状态点由 t-SNE 方法降维成 2 个维度,我们可以看到隐藏状态点倾向于构成集群。

图 2 展示了整个框架。我们首先在训练数据集上训练 RNN,然后再对应验证数据 V 的所有隐藏状态 H 上执行聚类,最后学习一个关于 V 的 FSA。再得到 FSA 后,我们可以使用它来测试未标记数据或直接画出图示。再训练 RNN 的第一步,我们利用了和 [ZWZZ16] 相同的策略,在这里忽略了细节。之后,我们会详细介绍隐藏状态聚类和 FSA 学习步骤(参见原文)。

图 2:本文提出算法的框架展示。黄色圆圈表示隐藏状态,由 h_t 表示,这里 t 是时间步。「A」是循环单元,接收输入 x_t 和 h_t-1 并输出 h_t。结构化 FSA 的双圆圈是接受状态。总体来说,该框架由三个步骤构成,即训练 RN 你模型、聚类隐藏状态和输出结构化 FSA。

完整的从 RNN 学习 FSA 的过程如算法 1 所示。我们将该方法称为 LISOR,并展示了两种不同的聚类算法。基于 k-means++ 的被称为「LISOR-k」,基于 k-means-x 的被称为「LISOR-x」。通过利用构成隐藏状态点的聚类倾向,LISOR-k 和 LISOR-x 都可以从 RNN 学习到良好泛化的 FSA。

实验结果

在这一部分,我们在人工和真实任务上进行了实验,并可视化了从对应 RNN 模型学习到的 FSA。除此之外,在两个任务中,我们讨论了我们如何从 FSA 解释 RNN 模型,以及展示使用学习到的 FSA 来做分类的准确率

第一个人工任务是在一组长度为 4 的序列中(只包含 0 和 1)识别序列「0110」(任务「0110」). 这是一个简单的只包含 16 个不同案例的任务。我们在训练集中包含了 1000 个实例,通过重复实例来提高准确率。我们使用包含所有可能值且没有重复的长度为 4 的 0-1 序列来学习 FSA,并随机生成 100 个实例来做测试。

第二个人工任务是确定一个序列是否包含三个连续的 0(任务「000」)。这里对于序列的长度没有限制,因此该任务有无限的实例空间,并且比任务「0110」更困难。我们随机生成 3000 个 0-1 训练实例,其长度是随机确定的。我们还生成了 500 个验证实例和 500 个测试实例。

表 2:分别基于 LISOR-k 和 LISOR-x 方法,当从 4 个 RNN 中学习到的 FSA 在任务「0110」的准确率首次达到 1.0 时,集群的数量(n_c)。注意这些值是越小越高效,并且可解释性越好。不同试验中训练得到的 RNN 模型使用了不同的参数初始化。

如表 2 所示,我们可以看到在从 MGU 学习到的 FSA 的平均集群数量总是能以最小的集群数量达到准确率 1.0。集群数量为 65 意味着 FSA 的准确率在直到 n_c 为 64 时都无法达到 1.0。每次试验的最小集群数量和平均最小集群数量加粗表示。

表 3:分别基于 LISOR-k 和 LISOR-x 方法,当从 4 个 RNNzho 中学习到的 FSA 在任务「000」的准确率首次达到 0.7 时,集群的数量(n_c)。注意这些值是越小越高效,并且可解释性越好。

图 3:在任务「0110」训练 4 个 RNN 时学习得到的 FSA 结构图示。集群数量 k 由 FSA 首次达到准确率 1.0 时的聚类数量决定。0110 的路由用红色表示。注意在图(d)中由 4 个独立于主要部分的节点。这是因为我们舍弃了当输入一个符号来学习一个确定性 FSA 时更小频率的转换。

图 4:在任务「000」训练 4 个 RNN 时学习得到的 FSA 结构图示。集群数量 k 由 FSA 首次达到准确率 0.7 时的集群数量决定。

图 7:在情感分析任务上训练的 MGU 学习到的 FSA。这里的 FSA 经过压缩,并且相同方向上的相同两个状态之间的词被分成同一个词类。例如,「word class 0-1」中的词全部表示从状态 0 转换为状态 1。

表 5:从状态 0 转换的词称为可接受状态(即包含积极电影评论的状态 1),其中大多数词都是积极的。这里括号中的数字表示词来源的 FSA 编号。

表 4:当集群数量为 2 时情感分类任务的准确率

理论南京大学周志华RNN结构化数据门控循环单元
111
相关数据
周志华人物

周志华分别于1996年6月、1998年6月和2000年12月于 南京大学计算机科学与技术系获学士、硕士和博士学位。主要从事人工智能、机器学习、数据挖掘 等领域的研究工作。主持多项科研课题,出版《机器学习》(2016)与《Ensemble Methods: Foundations and Algorithms》(2012),在一流国际期刊和顶级国际会议发表论文百余篇,被引用三万余次。

深度学习技术

深度学习(deep learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法。 深度学习是机器学习中一种基于对数据进行表征学习的算法,至今已有数种深度学习框架,如卷积神经网络和深度置信网络和递归神经网络等已被应用在计算机视觉、语音识别、自然语言处理、音频识别与生物信息学等领域并获取了极好的效果。

结构学习技术

结构化预测是监督学习,分类和回归的标准范式的一种推广。 所有这些可以被认为是找到一个能最大限度减少训练集损失的函数。

机器学习技术

机器学习是人工智能的一个分支,是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、计算复杂性理论等多门学科。机器学习理论主要是设计和分析一些让计算机可以自动“学习”的算法。因为学习算法中涉及了大量的统计学理论,机器学习与推断统计学联系尤为密切,也被称为统计学习理论。算法设计方面,机器学习理论关注可以实现的,行之有效的学习算法。

感知技术

知觉或感知是外界刺激作用于感官时,脑对外界的整体的看法和理解,为我们对外界的感官信息进行组织和解释。在认知科学中,也可看作一组程序,包括获取信息、理解信息、筛选信息、组织信息。与感觉不同,知觉反映的是由对象的各样属性及关系构成的整体。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

有限状态机技术

有限状态机(英语:finite-state machine,缩写:FSM)又称有限状态自动机,简称状态机,是表示有限个状态以及在这些状态之间的转移和动作等行为的数学模型。

神经网络技术

(人工)神经网络是一种起源于 20 世纪 50 年代的监督式机器学习模型,那时候研究者构想了「感知器(perceptron)」的想法。这一领域的研究者通常被称为「联结主义者(Connectionist)」,因为这种模型模拟了人脑的功能。神经网络模型通常是通过反向传播算法应用梯度下降训练的。目前神经网络有两大主要类型,它们都是前馈神经网络:卷积神经网络(CNN)和循环神经网络(RNN),其中 RNN 又包含长短期记忆(LSTM)、门控循环单元(GRU)等等。深度学习是一种主要应用于神经网络帮助其取得更好结果的技术。尽管神经网络主要用于监督学习,但也有一些为无监督学习设计的变体,比如自动编码器和生成对抗网络(GAN)。

准确率技术

分类模型的正确预测所占的比例。在多类别分类中,准确率的定义为:正确的预测数/样本总数。 在二元分类中,准确率的定义为:(真正例数+真负例数)/样本总数

t分布随机邻嵌入技术

t分布随机邻嵌入(t-SNE)是由Geoffrey Hinton和Laurens van der Maaten 开发的一种降维的机器学习算法。 这是一种非线性降维技术,特别适合将高维数据嵌入到二维或三维空间,然后可以在散点图中将其可视化。 具体来说,它通过二维或三维点对每个高维对象进行建模,使得类似的对象由附近的点建模,不相似的对象由远点建模。

分类问题技术

分类问题是数据挖掘处理的一个重要组成部分,在机器学习领域,分类问题通常被认为属于监督式学习(supervised learning),也就是说,分类问题的目标是根据已知样本的某些特征,判断一个新的样本属于哪种已知的样本类。根据类别的数量还可以进一步将分类问题划分为二元分类(binary classification)和多元分类(multiclass classification)。

降维技术

降维算法是将 p+1 个系数的问题简化为 M+1 个系数的问题,其中 M<p。算法执行包括计算变量的 M 个不同线性组合或投射(projection)。然后这 M 个投射作为预测器通过最小二乘法拟合一个线性回归模型。两个主要的方法是主成分回归(principal component regression)和偏最小二乘法(partial least squares)。

长短期记忆网络技术

长短期记忆(Long Short-Term Memory) 是具有长期记忆能力的一种时间递归神经网络(Recurrent Neural Network)。 其网络结构含有一个或多个具有可遗忘和记忆功能的单元组成。它在1997年被提出用于解决传统RNN(Recurrent Neural Network) 的随时间反向传播中权重消失的问题(vanishing gradient problem over backpropagation-through-time),重要组成部分包括Forget Gate, Input Gate, 和 Output Gate, 分别负责决定当前输入是否被采纳,是否被长期记忆以及决定在记忆中的输入是否在当前被输出。Gated Recurrent Unit 是 LSTM 众多版本中典型的一个。因为它具有记忆性的功能,LSTM经常被用在具有时间序列特性的数据和场景中。

概率图模型技术

在概率论和统计学中,概率图模型(probabilistic graphical model,PGM) ,简称图模型(graphical model,GM),是指一种用图结构来描述多元随机 变量之间条件独立关系的概率模型

深度神经网络技术

深度神经网络(DNN)是深度学习的一种框架,它是一种具备至少一个隐层的神经网络。与浅层神经网络类似,深度神经网络也能够为复杂非线性系统提供建模,但多出的层次为模型提供了更高的抽象层次,因而提高了模型的能力。

推荐文章
文章中图片只能显示左边一半,很影响阅读,希望能修改下排版