作者前言
我使用Iris数据集训练了一系列机器学习模型,从数据中的极端值合成了新数据点,并测试了许多机器学习模型来绘制出决策边界,这些模型可根据这些边界在2D空间中进行预测,这对于阐明目的和了解不同机器学习模型如何进行预测会很有帮助。
前沿的机器学习
机器学习模型可以胜过传统的计量经济学模型,这并没有什么新奇的,但是作为研究的一部分,我想说明某些模型为什么以及如何进行分类预测。我想展示我的二分类模型所依据的决策边界,也就是展示数据进行分类预测的分区空间。该问题以及代码经过一些调整也能够适用于多分类问题。
初始化
首先加载一系列程序包,然后新建一个logistic函数,以便稍后将log-odds转换为logistic概率函数。
library(dplyr) library(patchwork) library(ggplot2) library(knitr) library(kableExtra) library(purrr) library(stringr) library(tidyr) library(xgboost) library(lightgbm) library(keras) library(tidyquant) ##################### Pre-define some functions logit2prob <- function(logit){ odds <- exp(logit) prob <- odds / (1 + odds) return(prob) }
数据
我使用的iris数据集包含有关英国统计员Ronald Fisher在1936年收集的3种不同植物变量的信息。该数据集包含4种植物物种的不同特征,这些特征可区分33种不同物种(Setosa,Virginica和Versicolor)。但是,我的问题需要一个二元分类问题,而不是一个多分类问题。在下面的代码中,我导入了iris数据并删除了一种植物物种virginica,以将其从多重分类转变为二元分类问题。
data(iris) df <- iris %>% filter(Species != "virginica") %>% mutate(Species = +(Species == "versicolor")) str(df) ## 'data.frame': 100 obs. of 5 variables: ## $ Sepal.Length: num 5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ... ## $ Sepal.Width : num 3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ... ## $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ... ## $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ... ## $ Species : int 0 0 0 0 0 0 0 0 0 0 ...
我首先采用ggplot来绘制数据,以下储存的ggplot对象中,每个图仅更改x和y变量选择。
plt1 <- df %>% ggplot(aes(x = Sepal.Width, y = Sepal.Length, color = factor(Species))) + geom_point(size = 4) + theme_bw(base_size = 15) + theme(legend.position = "none") plt2 <- df %>% ggplot(aes(x = Petal.Length, y = Sepal.Length, color = factor(Species))) + geom_point(size = 4) + theme_bw(base_size = 15) + theme(legend.position = "none") plt3 <- df %>% ggplot(aes(x = Petal.Width, y = Sepal.Length, color = factor(Species))) + geom_point(size = 4) + theme_bw(base_size = 15) + theme(legend.position = "none") plt3 <- df %>% ggplot(aes(x = Sepal.Length, y = Sepal.Width, color = factor(Species))) + geom_point(size = 4) + theme_bw(base_size = 15) + theme(legend.position = "none") plt4 <- df %>% ggplot(aes(x = Petal.Length, y = Sepal.Width, color = factor(Species))) + geom_point(size = 4) + theme_bw(base_size = 15) + theme(legend.position = "none") plt5 <- df %>% ggplot(aes(x = Petal.Width, y = Sepal.Width, color = factor(Species))) + geom_point(size = 4) + theme_bw(base_size = 15) + theme(legend.position = "none") plt6 <- df %>% ggplot(aes(x = Petal.Width, y = Sepal.Length, color = factor(Species))) + geom_point(size = 4) + theme_bw(base_size = 15) + theme(legend.position = "none")
我还使用了新的patchwork 包,使展示ggplot结果变得很容易。下面的代码很直白的绘制了我们的图形(1个顶部图占满了网格空间的长度,2个中等大小的图,另一个单个图以及底部另外2个图)
(plt1) / (plt2 + plt3)
或者,我们可以将绘图重新布置为所需的任何方式,并通过以下方式进行绘图:
(plt1 + plt2) / (plt5 + plt6)
我觉得这看起来不错。