Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

针对分类问题的准确度悖论

分类问题机器学习的研究重点,而后者在实践中常常碰到非均衡数据集这个难题。非均衡数据集(imbalanced data)又称为非平衡数据集,指的是针对分类问题,数据集中各个类别所占比例并不平均。比如在网络广告行业,需要对用户是否点击网页上的广告进行建模。为了处理方便,我们记“点击广告”为类别1,“不点击广告”为类别0。因此这是一个二元分类问题。在训练模型的历史数据里有1000个数据点(1000行),其中类别1的数据点只有10个,剩下的990个数据全部为类别0。这就是一个非均衡数据集,类别之间的比例为99:1。与二元分类问题类似,多元分类问题同样会面对非均衡数据集这个难题。不过在这个问题上,多元分类的处理的方案与二元的相似,因此为了表述简洁利于理解,下面的讨论将针对二元分类问题

非均衡数据集在现实中是十分常见的。它给模型搭建带来了困难,如果不小心处理,会导致得到的模型结果毫无意义。在讨论这个话题之前,让我们稍稍离题一下,来看看所谓的准确度悖论(accuracy paradox)。

一、准确度悖论

对于二元分类问题,模型的预测结果按准确与否可以分为如下4类,见表1

1


真实值

1

0

预测值

1

真阳性(True positive

TP

伪阳性(False positive

FP

0

伪阴性(False negative

FN

真阴性(True negative

TN


其中,TPTN这两个部分都表示模型的预测结果是正确的,这两者之和的比例越高,说明模型的效果越好。由此可以定义评估模型效果的指标——准确度(accuraryACC)。

准确度这个指标看似很合理,但面对非均衡数据集时,这个指标会严重失真,甚至变得毫无意义。来看下面这个例子:数据集里有1000个数据点,其中990个为类别0,而剩下的10个为类别1,如图1所示。

图1

模型A对所有数据的预测都是类别0,因此这个模型其实并没有提供什么预测功能。但它的准确度却高达99%。模型B的预测效果其实很不错:对于类别110个数据里有9个预测正确;而对于类别0990个数据里有900个预测正确,但它的准确度只有90.9%远低于模型A

这就是所谓的准确度悖论:面对非均衡数据集时,准确度这个评估指标会使模型严重偏向占比更多的类别,导致模型的预测功能失效。

二、一个例子

非均衡数据集除了会引起准确度悖论外,它对搭建模型有什么影响呢?下面通过一个简单的例子来说明这个问题。我们按公式(2)产生模型数据,其中变量为因变量;为自变量;为随机扰动项,它服从逻辑分布。

由此可见,产生的模型数据完美地符合逻辑回归模型的假设。因此使用逻辑回归对数据建模,得到结果按理说应该非常好。但事实上,当数据集是均衡时,也就是说类别1所占比例大约为0.5时,模型效果是还不错。但当类别1所占比例接近0时,也就是数据集是非均衡时,模型的效果就很差了。虽然数据集里类别1的个数不变,但模型的预测结果几乎都是类别0,如图2a所示。正如上面讨论的那样,ACC这个指标在非均衡数据集里会失真,而AUC则可以保持稳定,能正确衡量模型的好坏,如图2b所示。

图2图2

上面的例子从直观上展示了非均衡数据集对搭建模型的影响。那么造成这种结果的原因是什么呢?从数学角度来讲,逻辑回归参数的估算公式如下:(惩罚项并不影响这里的讨论,因此我们在此省略掉惩罚项。)

在这个公式里,每个数据点的权重都是一样的,都为1。也就是说,模型对于类别1所承受的损失为:。这个值几乎等于模型对于类别0所承受的损失:。如果某一类别的数据特别多,不妨假定为类别0,那么在类别1某点的附件,极有可能存在大量的类别0。在这种情况下,根据公式(3),模型会选择“牺牲”类别1,从而导致预测结果几乎都为类别0

上面的结论并不只针对逻辑回归这个分类模型,对于其他分类模型,也同样成立。

三、解决方法

针对非均衡数据集,最常见也是最方便的解决方案是修改损失函数里不同类别的权重。以逻辑回归为例,将它的损失函数改写为公式(4)。

当类别1所占比例很少时,则增加,也就是增加模型对于类别1所承受的损失,反之亦然。在大多数情况下,类别权重的选择原则是,类别权重等于类别所占比例的倒数,如程序清单1中第78行代码所示。经过权重调整后,训练模型的数据集相当于回到了均衡状态。权重调整的代码非常简单,如第10行代码所示,通过“class_weight参数调整各个类别的权重。事实上,“class_weight”也可以被赋值为“balanced”,即“class_weight= 'balanced'”,这时模型会自动调整各个类别的权重

程序清单1 非平衡数据集

 1  |  from sklearn.linear_model import LogisticRegression

 2  |  

 3  |  def balanceData(X, Y):

 4  |       """

 5  |      过调整各个类别的比重,解决非均衡数据集的问题

 6  |       """

 7  |       positiveWeight = len(Y[Y>0]) / float(len(Y))

 8  |       classWeight = {1: 1. / positiveWeight, 0: 1. / (1 - positiveWeight)}

 9  |      # 了消除惩罚项的干,将惩罚系数设为很大

10  |      model = LogisticRegression(class_weight=classWeight, C=1e4)

11  |      model.fit(X, Y.ravel())

12  |      pred = model.predict(X)

13  |       return pred

经过权重调整后,模型的结果如图3所示。在处理非均衡数据集时,调整权重后的模型会错误地将很多类别0的数据预测为类别1。这与我们之前在上一节里的分析是一致的:类别1权重增加后,模型会因“刻意地珍惜”类别1,而选择“牺牲”类别0。尽管如此,调整之后的整体效果明显优于调整之前的(调整之后的AUC更大)。值得注意的是,ACCAUC这两个评估指标几乎相等,所以图形上它们两者重叠在了一起。

对于非均衡数据集,还有一些其他的解决方法,比如通过重新抽样(sampling),把多的类别变少或把少的类别变多。具体的细节在此就不做展开讨论了。

图3

注:这篇文章的大部分内容参考自我的新书《精通数据科学:从线性回归深度学习

唐亘的专栏
唐亘的专栏

唐亘,数据科学家,《精通数据科学:从线性回归到深度学习》一书作者。热爱并积极参与Apache Spark、 scikit-learn等开源项目。作为讲师和技术顾问为多家机构(包括惠普,华为,复旦大学等)提供百余场技术培训。

理论数据科学分类
2
相关数据
深度学习技术

深度学习(deep learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法。 深度学习是机器学习中一种基于对数据进行表征学习的算法,至今已有数种深度学习框架,如卷积神经网络和深度置信网络和递归神经网络等已被应用在计算机视觉、语音识别、自然语言处理、音频识别与生物信息学等领域并获取了极好的效果。

逻辑回归技术

逻辑回归(英语:Logistic regression 或logit regression),即逻辑模型(英语:Logit model,也译作“评定模型”、“分类评定模型”)是离散选择法模型之一,属于多重变量分析范畴,是社会学、生物统计学、临床、数量心理学、计量经济学、市场营销等统计实证分析的常用方法。

权重技术

线性模型中特征的系数,或深度网络中的边。训练线性模型的目标是确定每个特征的理想权重。如果权重为 0,则相应的特征对模型来说没有任何贡献。

机器学习技术

机器学习是人工智能的一个分支,是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、计算复杂性理论等多门学科。机器学习理论主要是设计和分析一些让计算机可以自动“学习”的算法。因为学习算法中涉及了大量的统计学理论,机器学习与推断统计学联系尤为密切,也被称为统计学习理论。算法设计方面,机器学习理论关注可以实现的,行之有效的学习算法。

真正例技术

被模型正确地预测为正类别的样本。例如,模型推断出某封电子邮件是垃圾邮件,而该电子邮件确实是垃圾邮件。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

损失函数技术

在数学优化,统计学,计量经济学,决策理论,机器学习和计算神经科学等领域,损失函数或成本函数是将一或多个变量的一个事件或值映射为可以直观地表示某种与之相关“成本”的实数的函数。

线性回归技术

在现实世界中,存在着大量这样的情况:两个变量例如X和Y有一些依赖关系。由X可以部分地决定Y的值,但这种决定往往不很确切。常常用来说明这种依赖关系的最简单、直观的例子是体重与身高,用Y表示他的体重。众所周知,一般说来,当X大时,Y也倾向于大,但由X不能严格地决定Y。又如,城市生活用电量Y与气温X有很大的关系。在夏天气温很高或冬天气温很低时,由于室内空调、冰箱等家用电器的使用,可能用电就高,相反,在春秋季节气温不高也不低,用电量就可能少。但我们不能由气温X准确地决定用电量Y。类似的例子还很多,变量之间的这种关系称为“相关关系”,回归模型就是研究相关关系的一个有力工具。

逻辑技术

人工智能领域用逻辑来理解智能推理问题;它可以提供用于分析编程语言的技术,也可用作分析、表征知识或编程的工具。目前人们常用的逻辑分支有命题逻辑(Propositional Logic )以及一阶逻辑(FOL)等谓词逻辑。

分类问题技术

分类问题是数据挖掘处理的一个重要组成部分,在机器学习领域,分类问题通常被认为属于监督式学习(supervised learning),也就是说,分类问题的目标是根据已知样本的某些特征,判断一个新的样本属于哪种已知的样本类。根据类别的数量还可以进一步将分类问题划分为二元分类(binary classification)和多元分类(multiclass classification)。

推荐文章
暂无评论
暂无评论~