用R软件做分类树和回归树(CART)
决策树(Decision Tree)又称为判定树,是运用于分类的一种树结构。其中的每个内部结点(internal node)代表对某个属性的一次测试,每条边代表一个测试结果,叶结点(leaf)代表某个类(class)或者类的分布(class distribution),最上面的结点是根结点。决策树提供了一种展示类似在什么条件下会得到什么值这类规则的方法。
构造决策树是采用自上而下的递归构造方法。以多叉树为例,如果一个训练数据集中的数据有几种属性值,则按照属性的各种取值把这个训练数据集再划分为对应的几个子集(分支),然后再依次递归处理各个子集。反之,则作为叶结点。
决策树构造的结果是一棵二叉或多叉树,它的输入是一组带有类别标记的训练数据。二叉树的内部结点(非叶结点)一般表示为一个逻辑判断,如形式为(a = b)的逻辑判断,其中a 是属性,b是该属性的某个属性值;树的边是逻辑判断的分支结果。多叉树(ID3)的内部结点是属性,边是该属性的所有取值,有几个属性值,就有几条边。树的叶结点都是类别标记。
使用决策树进行分类分为两步:
第1步:利用训练集建立并精化一棵决策树,建立决策树模型。这个过程实际上是一个从数据中获取知识,进行机器学习的过程。
第2步:利用生成完毕的决策树对输入数据进行分类。对输入的记录,从根结点依次测试记录的属性值,直到到达某个叶结点,从而找到该记录所在的类。
问题的关键是建立一棵决策树。这个过程通常分为两个阶段:
第一阶段,建树(Tree Building):决策树建树算法见下,这是一个递归的过程,最终将得到一棵树。
第二阶段,剪枝(Tree Pruning):剪枝的目的是降低由于训练集存在噪声而产生的起伏。
分类树和回归树(CART:Classification and Regression)
描述给定预测变量X后,变量Y条件分布的一种方法,使用二叉树将预测空间递归地划分为若干个子集,Y在这些子集上的分布是连续均匀的,树中的叶节点对应着划分的不同区域,划分是由与每个内部节点相关的分支规则(Splitting rules)确定的.通过从树的根节点逐渐到叶节点移动,每个预测样本被赋予一个叶节点,Y在该节点上的分布也被确定。利用CART进行预测同样需要一个学习样本(训练样本)对CART进行建树和评估,然后利用其进行预测。以下面的数据结构为例:
其中,为属性变量,可以是连续或离散的;为类别变量,当为离散时该模型为分类树,当为有序变量时,模型为回归树。
根据给定的训练样本进行建模的步骤主要有:
CART的原理或细节,相关数据挖掘或机器学习书籍都有阐述,另外,百度了相关博客,个人感觉RaySaint的博客把握了CART的关键因素。详见:
http://underthehood.blog.51cto.com/2531780/564685
R软件完成CART
#1调用rpart包进行CART建模
library(rpart)
#1前列腺癌数据stagec
head(stagec)
progstat = factor(stagec$pgstat, levels = 0:1, labels = c("No", "Prog"))
#2建树,method主要有 "anova", "poisson", "class" "exp"。通常作生存分析选exp,因变量是因子变量选class,作poisson回归选poisson,其他情况通常选择anova;
cfit = rpart(progstat ~ age + eet + g2 + grade + gleason + ploidy,data = stagec, method ='class')
#输出结果
print(cfit)
#作树图
par(mar = rep(0.1, 4))
plot(cfit)
#添加标签
text(cfit)
#对分类结果作混淆矩阵
(temp = with(stagec, table(cut(grade, c(0, 2.5, 4)),
cut(gleason, c(2, 5.5, 10)),exclude = NULL)))
#3剪枝
cfit2=prune(cfit,cp=.02)
plot(cfit2)
text(cfit2)
printcp(cfit2)#输出剪枝表格
summary(cfit2)#输出CART完整细节,包括printcp内容
#4rpart中相关参数,rpart(,..,parms=())
"Anova"分类没有参数
"Poisson"分类只有单一参数:率的先验分布的变异系数,默认为1
"Exp"分类参数同poisson
"Class"分类包含的参数最为复杂,包括先验概率、损失矩阵或分类指标(Gini或Information)。#4.1比较Gini和Information分类指标,以自带汽车消费数据为例cu.summary
head(cu.summary)#查阅数据
fit1 = rpart(Reliability ~ Price + Country + Mileage + Type, data = cu.summary, parms = list(split = 'gini'))
fit2 = rpart(Reliability ~ Price + Country + Mileage + Type,data = cu.summary, parms = list(split = 'information'))
par(mfrow = c(1,2), mar = rep(0.1, 4))
plot(fit1, margin = 0.05); text(fit1, use.n = TRUE, cex = 0.8)
plot(fit2, margin = 0.05); text(fit2, use.n = TRUE, cex = 0.8)
#4.2比较parms中的先验概率(prior)和损失矩阵(loss)参数,以rpart自带驼背数据kyphosis为例
#查阅数据
head(kyphosis)
#默认的先验概率为Kyphosis两类的频率比fit1 = rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)#定义先验概率prior=c(..,..)fit2 = rpart(Kyphosis ~ Age + Number + Start, data = kyphosis, parms = list(prior = c(0.65, 0.35)))
##loss参数设置,首先一个损失矩阵lmat
lmat = matrix(c(0,3, 4,0), nrow = 2, ncol = 2, byrow = FALSE)fit3 = rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,parms = list(loss = lmat))par(mfrow = c(1, 3), mar = rep(0.1, 4))plot(fit1); text(fit1, use.n = TRUE, all = TRUE, cex = 0.8)plot(fit2); text(fit2, use.n = TRUE, all = TRUE, cex = 0.8)plot(fit3); text(fit3, use.n = TRUE, all = TRUE, cex = 0.8)
二、回归树
1.通常默认anova用来作回归树,以汽车消费数据car90为例,该数据包括34个变量110条观察值。
#查阅car90数据
head(car90);str(car90)
#剔除轮胎尺寸、型号等3个因素型变量(factor variable):"Rim", "Tires", "Model2"
cars = car90[, -match(c("Rim", "Tires", "Model2"), names(car90))]#建立回归树模型carfit = rpart(Price/1000 ~ ., data=cars)carfit;printcp(carfit);summary(carfit,cp=0.1)plot(carfit);text(carfit)
#图示不同分类的误差,par(mfrow=c(1,2)); rsq.rpart(carfit)
2.Poisson回归树
以数据solder为例
#查看数据,变量属性
head(solder);str(solder)
#建立poisson回归树
sfit = rpart(skips ~ Opening + Solder + Mask + PadType + Panel,data = solder, method = 'poisson',control = rpart.control(cp = 0.05, maxcompete = 2))sfit;printcp(sfit);summary(sfit,cp=.1)
3.生存模型回归树
#以前列腺癌数据stagec为例,调用survival包进行生存分析
library(survival)temp = coxph(Surv(pgtime, pgstat) ~ 1, stagec)newtime = predict(temp, type = 'expected')
pfit <- rpart(Surv(pgtime, pgstat) ~ age + eet + g2 + grade +gleason + ploidy, data = stagec)
pfit2 <- prune(pfit, cp = 0.016)#进行减枝
par(mar = rep(0.2, 4))
plot(pfit2, uniform = TRUE, branch = 0.4, compress = TRUE)
text(pfit2, use.n = TRUE)
数据分析咨询请扫描二维码