登录
首页精彩阅读R语言利用nnet包训练神经网络模型
R语言利用nnet包训练神经网络模型
2018-05-23
收藏

R语言利用nnet包训练神经网络模型

R语言提供了另外一个能够处理人工神经网络的算法包nnet,该算法提供了传统的前馈反向传播神经网络算法的实现。
操作
安装包与数据分类:
library(nnet)
data("iris")
set.seed(2)
ind = sample(2,nrow(iris),replace = TRUE,prob = c(0.7,0.3))
trainset = iris[ind == 1,]
testset = iris[ind == 2,]
使用nnet包训练神经网络

iris.nn = nnet(Species ~ .,data = trainset,size = 2,rang = 0.1,decay = 5e-4,maxit = 200)
# weights:  19
initial  value 114.539765
iter  10 value 52.100312
iter  20 value 50.231442
iter  30 value 49.526599
iter  40 value 49.402229
iter  50 value 44.680338
iter  60 value 5.254389
iter  70 value 2.836695
iter  80 value 2.744315
iter  90 value 2.687069
iter 100 value 2.621556
iter 110 value 2.589096
iter 120 value 2.410539
iter 130 value 2.096461
iter 140 value 1.938717
iter 150 value 1.857105
iter 160 value 1.825393
iter 170 value 1.817409
iter 180 value 1.815591
iter 190 value 1.815030
iter 200 value 1.814746
final  value 1.814746
stoppedafter 200 iterations

调用summary( )输出训练好的神经网络

summary(iris.nn)
a 4-2-3 network with 19 weights
options were - softmax modelling  decay=5e-04
 b->h1 i1->h1 i2->h1 i3->h1 i4->h1
-20.60   0.31  -3.84   3.36   7.72
 b->h2 i1->h2 i2->h2 i3->h2 i4->h2
 -7.15   1.50   2.49  -4.14   5.59
 b->o1 h1->o1 h2->o1
 -7.28  -3.67  13.16
 b->o2 h1->o2 h2->o2
 15.90 -16.64 -19.40
 b->o3 h1->o3 h2->o3
 -8.62  20.31   6.24
在应用函数时可以实现分类观测,数据源,隐蔽单元个数(size参数),初始随机数权值(rang参数),权值衰减参数(decay参数),最大迭代次数(maxit),整个过程会一直重复直至拟合准则值与衰减项收敛。
使用模型iris.nn模型完成对测试数据集的预测

iris.predict = predict(iris.nn,testset,type = "class")
nn.table = table(testset$Species,iris.predict)
nn.table
            iris.predict
             setosa versicolor virginica
  setosa         17          0         0
  versicolor      0         13         1
  virginica       0          2        13

基于分类表得到混淆矩阵

confusionMatrix(nn.table)
Confusion Matrix and Statistics

            iris.predict
             setosa versicolor virginica
  setosa         17          0         0
  versicolor      0         13         1
  virginica       0          2        13

Overall Statistics

               Accuracy : 0.9348         
                 95% CI : (0.821, 0.9863)
    No Information Rate : 0.3696         
    P-Value [Acc > NIR] : 1.019e-15      

                  Kappa : 0.9019         
 Mcnemar's Test P-Value : NA             

Statistics by Class:

                     Class: setosa Class: versicolor Class: virginica
Sensitivity                 1.0000            0.8667           0.9286
Specificity                 1.0000            0.9677           0.9375
Pos Pred Value              1.0000            0.9286           0.8667
Neg Pred Value              1.0000            0.9375           0.9677
Prevalence                  0.3696            0.3261           0.3043
Detection Rate              0.3696            0.2826           0.2826
Detection Prevalence        0.3696            0.3043           0.3261
Balanced Accuracy           1.0000            0.9172           0.9330
在调用predict函数时,我们明确了type参数为class,因此输出的是预测的类标号而非概率矩阵。接下来调用table函数根据预测结果和testset的实际类标号生成分类表,最后利用建立的分类表使用table函数根据caret中的confusionMatrix方法对训练好的神经网络预测性能评估。

数据分析咨询请扫描二维码

客服在线
立即咨询