Rでニューラルネットワークの可視化

2018年11月14日

Rにおけるニューラルネットワーク可視化パッケージ

neuralnet関数は標準でplot()関数によりで計算グラフを可視化できる。neuralnet関数のような機能のない他のニューラルネットワークパッケージを使用した時の計算グラフの可視化方法を以下にメモ。

  • plot.nn関数
  • plotnet関数

下準備

サンプルデータはirisを使用

学習器の作成

d=iris
d$Species <- as.factor(d$Species)

#train_test_split
set.seed(0)
sample <- sample.int(n = nrow(d), size = floor(0.80*nrow(d)), replace = F)
train <- d[sample, ]
test <- d[-sample, ]
summary(train)

#nnet
library(nnet)
nn1=nnet(Species~., size=5, data=train)
pred_nn1 <- predict(nn1, test,type="class")
table(test$Species,pred_nn1)

nnetを可視化

いずれも色で正負を、太さで数値の大小を表してます。

plot.nn関数

source("http://hosho.ees.hokudai.ac.jp/~kubo/log/2007/img07/plot.nn.txt")
plot.nn(nn1)

 


plot.nnet関数

install.packages("NeuralNetTools")
library(NeuralNetTools)
plotnet(nn1)

 

 

ちなみにneuralnet関数では

library(caret) 
tmp <- dummyVars(~.,data=train) 
dummy <- as.data.frame(predict(tmp, train))
library("neuralnet") 
f = Species.setosa + Species.versicolor + Species.virginica~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width 
nn2 <- neuralnet(formula = f, data = dummy)<br>plot(nn2)

 

変数が多くなる場合は横向きのplotnet関数による可視化のほうが見やすいですね。plotnet関数は、nnetだけでなくRSNNSやcaretで作成したニューラルネットワークも可視化できて適応範囲が広いので使いやすいです。