FC2カウンター FPGAの部屋 TensorFlow, Keras

FPGAやCPLDの話題やFPGA用のツールの話題などです。 マニアックです。 日記も書きます。

FPGAの部屋

FPGAの部屋の有用と思われるコンテンツのまとめサイトを作りました。Xilinx ISEの初心者の方には、FPGAリテラシーおよびチュートリアルのページをお勧めいたします。

TensorFlow + Kerasを使ってみた8(全結合層の統計情報)

TensorFlow + Kerasを使ってみた7(畳み込み層の統計情報)”の続き。

前回は、畳み込み層の最大値、最小値、絶対値の最大値、最小値、標準偏差などの統計情報を取得した。今回は、全結合層の統計情報を取得しよう。
なお、使用するのは MNIST の手書き数字を認識するCNN で畳み込み層の特徴マップの数は 10 個となっている。

全結合層1層目
畳み込み層同様に、model.summary() で取得した各層の情報を元に全結合層1層目の中間出力を取り出す。

# 1番目のDence layer1の中間出力を取り出す 
from keras.models import Model

dence_layer1_name = 'dense_5'

dence_layer1 = model.get_layer(dence_layer1_name)
dence_layer1_wb = dence_layer1.get_weights()

dence_layer1_model = Model(inputs=model.input,
                                 outputs=model.get_layer(dence_layer1_name).output)
dence_layer1_output = dence_layer1_model.predict(x_test, verbose=1)


表示を示す。

10000/10000 [==============================] - 2s 174us/step


重みとバイアスの配列の形状を取得した。

print(dence_layer1_weight.shape)
print(dence_layer1_bias.shape)

(1440, 100)
(100,)


全結合層1層目の重みとバイアスの最大値、最小値、絶対値の最大値、最小値を取得しよう。

print("np.max(dence_layer1_weight) = {0}".format(np.max(dence_layer1_weight)))
print("np.min(dence_layer1_weight) = {0}".format(np.min(dence_layer1_weight)))
abs_dence_layer1_weight = np.absolute(dence_layer1_weight)
print("np.max(abs_dence_layer1_weight) = {0}".format(np.max(abs_dence_layer1_weight)))
print("np.min(abs_dence_layer1_weight) = {0}".format(np.min(abs_dence_layer1_weight)))
print("np.max(dence_layer1_bias) = {0}".format(np.max(dence_layer1_bias)))
print("np.min(dence_layer1_bias) = {0}".format(np.min(dence_layer1_bias)))
abs_dence_layer1_bias = np.absolute(dence_layer1_bias)
print("np.max(abs_dence_layer1_bias) = {0}".format(np.max(abs_dence_layer1_bias)))
print("np.min(abs_dence_layer1_bias) = {0}".format(np.min(abs_dence_layer1_bias)))

np.max(dence_layer1_weight) = 0.287210673094
np.min(dence_layer1_weight) = -0.320384502411
np.max(abs_dence_layer1_weight) = 0.320384502411
np.min(abs_dence_layer1_weight) = 1.92374045582e-07
np.max(dence_layer1_bias) = 0.105059452355
np.min(dence_layer1_bias) = -0.0615252479911
np.max(abs_dence_layer1_bias) = 0.105059452355
np.min(abs_dence_layer1_bias) = 0.000534977589268


全結合層1層目の出力の標準偏差、最大値、最小値、絶対値の最大値、最小値を取得しよう。

print("dence_layer1_output = {0}".format(dence_layer1_output.shape))
print("np.std(dence_layer1_output) = {0}".format(np.std(dence_layer1_output)))
print("np.max(dence_layer1_output) = {0}".format(np.max(dence_layer1_output)))
print("np.min(dence_layer1_output) = {0}".format(np.min(dence_layer1_output)))
abs_dence_layer1_output = np.absolute(dence_layer1_output)
print("np.max(abs_dence_layer1_output) = {0}".format(np.max(abs_dence_layer1_output)))
print("np.min(abs_dence_layer1_output) = {0}".format(np.min(abs_dence_layer1_output)))

dence_layer1_output = (10000, 100)
np.std(dence_layer1_output) = 3.02382802963
np.max(dence_layer1_output) = 14.2637271881
np.min(dence_layer1_output) = -13.8859920502
np.max(abs_dence_layer1_output) = 14.2637271881
np.min(abs_dence_layer1_output) = 4.04380261898e-06


全結合層1層目の重みのグラフを示す。

# Dence layer1のweightのグラフ
dence_layer1_weight_f = dence_layer1_weight.flatten()
plt.plot(dence_layer1_weight_f)
plt.title('dence_layer1_weight')
plt.show()


tensorflow_keras_47_180530.png

全結合層1層目のバイアスのグラフを示す。

# Dence layer1のbiasのグラフ
dence_layer1_bias_f = dence_layer1_bias.flatten()
plt.plot(dence_layer1_bias_f)
plt.title('dence_layer1_bias')
plt.show()


tensorflow_keras_48_180530.png


全結合層2層目
model.summary() で取得した各層の情報を元に全結合層2層目の中間出力を取り出す。

# 2番目のDence layer2の中間出力を取り出す 
from keras.models import Model

dence_layer2_name = 'dense_6'

dence_layer2 = model.get_layer(dence_layer2_name)
dence_layer2_wb = dence_layer2.get_weights()

dence_layer2_model = Model(inputs=model.input,
                                 outputs=model.get_layer(dence_layer2_name).output)
dence_layer2_output = dence_layer2_model.predict(x_test, verbose=1)


表示を示す。

10000/10000 [==============================] - 2s 167us/step


重みとバイアスの配列の形状を取得した。

print(dence_layer2_weight.shape)
print(dence_layer2_bias.shape)

(100, 10)
(10,)


全結合層2層目の重みとバイアスの最大値、最小値、絶対値の最大値、最小値を取得しよう。

print("np.max(dence_layer2_weight) = {0}".format(np.max(dence_layer2_weight)))
print("np.min(dence_layer2_weight) = {0}".format(np.min(dence_layer2_weight)))
abs_dence_layer2_weight = np.absolute(dence_layer2_weight)
print("np.max(abs_dence_layer2_weight) = {0}".format(np.max(abs_dence_layer2_weight)))
print("np.min(abs_dence_layer2_weight) = {0}".format(np.min(abs_dence_layer2_weight)))
print("np.max(dence_layer2_bias) = {0}".format(np.max(dence_layer2_bias)))
print("np.min(dence_layer2_bias) = {0}".format(np.min(dence_layer2_bias)))
abs_dence_layer2_bias = np.absolute(dence_layer2_bias)
print("np.max(abs_dence_layer2_bias) = {0}".format(np.max(abs_dence_layer2_bias)))
print("np.min(abs_dence_layer2_bias) = {0}".format(np.min(abs_dence_layer2_bias)))

np.max(dence_layer2_weight) = 0.420090407133
np.min(dence_layer2_weight) = -0.625470399857
np.max(abs_dence_layer2_weight) = 0.625470399857
np.min(abs_dence_layer2_weight) = 0.000126185041154
np.max(dence_layer2_bias) = 0.0749695450068
np.min(dence_layer2_bias) = -0.0558836981654
np.max(abs_dence_layer2_bias) = 0.0749695450068
np.min(abs_dence_layer2_bias) = 0.00171886803582


全結合層2層目の出力の標準偏差、最大値、最小値、絶対値の最大値、最小値を取得しよう。

print("dence_layer2_output = {0}".format(dence_layer2_output.shape))
print("np.std(dence_layer2_output) = {0}".format(np.std(dence_layer2_output)))
print("np.max(dence_layer2_output) = {0}".format(np.max(dence_layer2_output)))
print("np.min(dence_layer2_output) = {0}".format(np.min(dence_layer2_output)))
abs_dence_layer2_output = np.absolute(dence_layer2_output)
print("np.max(abs_dence_layer2_output) = {0}".format(np.max(abs_dence_layer2_output)))
print("np.min(abs_dence_layer2_output) = {0}".format(np.min(abs_dence_layer2_output)))



dence_layer2_output = (10000, 10)
np.std(dence_layer2_output) = 9.34499263763
np.max(dence_layer2_output) = 30.0013465881
np.min(dence_layer2_output) = -35.2990074158
np.max(abs_dence_layer2_output) = 35.2990074158
np.min(abs_dence_layer2_output) = 0.000138353556395

全結合層2層目の重みのグラフを示す。

# Dence layer2のweightのグラフ
dence_layer2_weight_f = dence_layer2_weight.flatten()
plt.plot(dence_layer2_weight_f)
plt.title('dence_layer2_weight')
plt.show()


tensorflow_keras_49_180530.png

全結合層2層目のバイアスのグラフを示す。

# Dence layer2のbiasのグラフ
dence_layer2_bias_f = dence_layer2_bias.flatten()
plt.plot(dence_layer2_bias_f)
plt.title('dence_layer2_bias')
plt.show()


tensorflow_keras_50_180530.png
  1. 2018年05月30日 04:26 |
  2. TensorFlow, Keras
  3. | トラックバック:0
  4. | コメント:0

TensorFlow + Kerasを使ってみた7(畳み込み層の統計情報)

TensorFlow + Kerasを使ってみた6(層構成の変更、畳み込み層の統計情報)”の続き。

前回は、層の構成を変更して、学習を行った。今回は、畳み込み層の統計情報を取得してみよう。
なお、使用するのは MNIST の手書き数字を認識するCNN で畳み込み層の特徴マップの数は 10 個となっている。

前回、model.summary() で取得した各層の情報を元に畳み込み層の中間出力を取り出そう。

# Convolution layerの中間出力を取り出す 
from keras.models import Model

conv_layer_name = 'conv2d_4'

conv_layer = model.get_layer(conv_layer_name)
conv_layer_wb = conv_layer.get_weights()

conv_layer_model = Model(inputs=model.input,
                                 outputs=model.get_layer(conv_layer_name).output)
conv_output = conv_layer_model.predict(x_test, verbose=1)


表示を示す。

10000/10000 [==============================] - 1s 150us/step


最初に重みやバイアスの配列の構成やその値を見てみよう。

conv_layer_weight = conv_layer_wb[0]
conv_layer_bias = conv_layer_wb[1]

print(conv_layer_weight.shape)
print(conv_layer_weight.T.shape)
print(conv_layer_bias.shape)


結果を示す。

(5, 5, 1, 10)
(10, 1, 5, 5)
(10,)


conv_layer_weight は (5, 5, 1, 10) なので、いつもあつかている配列の構成ではない。そして、conv_layer_weight.T で転置してみたところ、いつも使っている構成の (10, 1, 5, 5) つまり、カーネル数、チャネル数、縦の幅、横の幅の配列になった。

print("conv_layer_weight.T = {0}".format(conv_layer_weight.T))

で転置した重みの配列を示す。

conv_layer_weight.T = [[[[ 0.20261094 -0.3398506  -0.5767307   0.11835691  0.13021287]
   [-0.07934965 -0.33518496 -0.4275438   0.25123549  0.38388866]
   [-0.22467291 -0.39792794 -0.07211141  0.38731813  0.24981308]
   [-0.43532223 -0.08618319  0.3187846   0.27912328  0.02272184]
   [-0.23157348  0.16632372  0.2654636   0.15583257 -0.04710154]]]


 [[[-0.48611653  0.15439186  0.4068115   0.3514016   0.16548221]
   [-0.6372757  -0.34480083  0.3836496   0.3169199   0.2640638 ]
   [-0.40548003 -0.56099683  0.10779987  0.32510042  0.3619229 ]
   [-0.33244497 -0.3915109  -0.1230321   0.2981098   0.35238296]
   [-0.13435255 -0.41839477 -0.4722871  -0.10132303  0.11304493]]]


 [[[-0.13130069  0.06223634  0.10510171  0.02183475 -0.16628554]
   [ 0.18190795  0.35684425  0.25642243  0.00863578  0.12985978]
   [ 0.15537558 -0.11242905 -0.2288756   0.04026176  0.08550146]
   [-0.1676253  -0.44136783 -0.29937005 -0.0171281   0.2620432 ]
   [-0.14785497 -0.10125857  0.12721944  0.05586093  0.10579447]]]


 [[[ 0.15156339 -0.20048767 -0.5791418  -0.65549827 -0.25779864]
   [ 0.4940948   0.3314954  -0.1274401  -0.3982863  -0.3313806 ]
   [ 0.22510909  0.45060343  0.15244117 -0.23712645 -0.02554286]
   [ 0.19534814  0.11640821  0.2987521  -0.04862794  0.04132852]
   [ 0.24242142  0.0540004  -0.00865097 -0.0300091   0.12885101]]]


 [[[ 0.36471996  0.35694337  0.27650365  0.35590482  0.13169082]
   [ 0.11910628  0.07778469  0.19447733  0.06036808 -0.12147922]
   [ 0.0868587   0.1454417   0.02258768 -0.2499182  -0.19614659]
   [-0.26476386 -0.27914035 -0.4387378  -0.33735904 -0.03323634]
   [-0.40310845 -0.43084973 -0.27778476 -0.2462857   0.04993651]]]


 [[[-0.08484216  0.19511358  0.58113253 -0.12703945 -0.516542  ]
   [ 0.17010233  0.2240115   0.23622094 -0.31102535 -0.59745365]
   [ 0.397805    0.23805015 -0.1035163  -0.45656392 -0.34286296]
   [ 0.18052064 -0.28208354 -0.29351595 -0.36484626  0.06465741]
   [-0.20084426 -0.30468363 -0.2777929   0.08292956 -0.01636941]]]


 [[[-0.05747546  0.10129268  0.0927546   0.01556351 -0.16821466]
   [-0.0250085   0.140934    0.12933072  0.19052765  0.20077062]
   [ 0.2489682   0.18465307  0.23520534  0.26735055  0.24849436]
   [ 0.00098434 -0.29655868 -0.13283624 -0.11904856 -0.02703394]
   [-0.31173185 -0.3589846  -0.2216169   0.05286852 -0.00669706]]]


 [[[-0.46006405 -0.41662437 -0.26404095 -0.27005908  0.00341533]
   [-0.07625411 -0.01859824 -0.0235228   0.0303653   0.10755768]
   [ 0.07276727  0.20107509  0.15815544  0.3283318   0.23039222]
   [ 0.21414295  0.14830865  0.24796312  0.01516124 -0.05039264]
   [ 0.01465091  0.08253051 -0.08803863  0.01456806 -0.17668988]]]


 [[[ 0.09082198  0.38919494  0.33294797 -0.5168951  -0.62100536]
   [-0.03360464  0.21474971  0.37199846 -0.29824486 -0.6191712 ]
   [ 0.11340497  0.20264329  0.37084493 -0.32331055 -0.5669018 ]
   [ 0.20493641  0.2751836   0.10829608 -0.20219678 -0.39315876]
   [ 0.14139216  0.20002662  0.17661056 -0.22110288 -0.28045934]]]


 [[[ 0.12587526 -0.01364575 -0.2322505  -0.14462651 -0.03129309]
   [ 0.171594    0.22744659 -0.05975187 -0.18951881 -0.2751198 ]
   [ 0.19060391  0.12572204  0.3344037   0.26089588 -0.12050828]
   [-0.17589498 -0.02884873  0.20712087  0.19588387  0.0149854 ]
   [-0.2094244  -0.37416157 -0.08472645  0.12522626  0.06411268]]]]


print("conv_layer_bias = {0}".format(conv_layer_bias))

で表示したバイアス値を示す。

conv_layer_bias = [-0.00722218 0.00386539 -0.10034832 -0.10226133 -0.00783706 -0.00266487
-0.15441592 -0.17244887 0.00067333 -0.17412803]


次に、取得した畳み込み層の中間出力を解析して重みやバイアスの最大値、最小値、絶対値の最大値、最小値、出力の標準偏差、最大値、最小値、絶対値の最大値、最小値を見てみよう。
Python コードを示す。

print("np.max(conv_layer_weight) = {0}".format(np.max(conv_layer_weight)))
print("np.min(conv_layer_weight) = {0}".format(np.min(conv_layer_weight)))
abs_conv_layer_weight = np.absolute(conv_layer_weight)
print("np.max(abs_conv_layer_weight) = {0}".format(np.max(abs_conv_layer_weight)))
print("np.min(abs_conv_layer_weight) = {0}".format(np.min(abs_conv_layer_weight)))

print("np.max(conv_layer_bias) = {0}".format(np.max(conv_layer_bias)))
print("np.min(conv_layer_bias) = {0}".format(np.min(conv_layer_bias)))
abs_conv_layer_bias = np.absolute(conv_layer_bias)
print("np.max(abs_conv_layer_bias) = {0}".format(np.max(abs_conv_layer_bias)))
print("np.min(abs_conv_layer_bias) = {0}".format(np.min(abs_conv_layer_bias)))

print("conv_output = {0}".format(conv_output.shape))
print("np.std(conv_output) = {0}".format(np.std(conv_output)))
print("np.max(conv_output) = {0}".format(np.max(conv_output)))
print("np.min(conv_output) = {0}".format(np.min(conv_output)))

abs_conv_output = np.absolute(conv_output)
print("np.max(abs_conv) = {0}".format(np.max(abs_conv_output)))
print("np.min(abs_conv) = {0}".format(np.min(abs_conv_output)))


出力を示す。

np.max(conv_layer_weight) = 0.581132531166
np.min(conv_layer_weight) = -0.65549826622
np.max(abs_conv_layer_weight) = 0.65549826622
np.min(abs_conv_layer_weight) = 0.000984335667454
np.max(conv_layer_bias) = 0.0038653917145
np.min(conv_layer_bias) = -0.17412802577
np.max(abs_conv_layer_bias) = 0.17412802577
np.min(abs_conv_layer_bias) = 0.000673334288877
conv_output = (10000, 24, 24, 10)
np.std(conv_output) = 0.691880404949
np.max(conv_output) = 3.46592283249
np.min(conv_output) = -4.23804473877
np.max(abs_conv) = 4.23804473877
np.min(abs_conv) = 7.68341124058e-09


畳み込み層の重みのグラフを書いてみた。

# Convolution layerのweightのグラフ
conv_layer_weight_f = conv_layer_weight.flatten()
plt.plot(conv_layer_weight_f)
plt.title('conv_layer_weight')
plt.show()


tensorflow_keras_45_180529.png

畳み込み層のバイアスのグラフを書いてみた。

# Convolution layerのbiasのグラフ
conv_layer_bias_f = conv_layer_bias.flatten()
plt.plot(conv_layer_bias_f)
plt.title('conv_layer_bias')
plt.show()


tensorflow_keras_46_180529.png
  1. 2018年05月29日 04:39 |
  2. TensorFlow, Keras
  3. | トラックバック:0
  4. | コメント:0

TensorFlow + Kerasを使ってみた6(層構成の変更、学習)

TensorFlow + Kerasを使ってみた5(モデルの可視化、サーマリ)”の続き。

前回は、モデルの可視化で層の構成情報のPNGファイルを出力し、レイヤのサーマリを出力した。今回は、層の構成を変更して、畳み込み層の学習を行った。
なお、使用するのは MNIST の手書き数字を認識するCNN で畳み込み層の特徴マップの数は 10 個となっている。

なぜ、層構成を変更するか?というと、今のままでは、例えば畳み込み層に activation として relu が組み込まれていたが、そうすると層の出力を取り出しても 0 以下が取り除かれている状態になっているからだ。飽和演算をすれば問題ないのだが、量子化ビット幅のマイナス側が設定できないので、回ってしまってマイナスがプラスになってしまう可能性がある。
そこで、畳み込み層は activation を削除して、ReLU は新たに Activation 層を追加することにした。これは 全結合層(dence)も同様だ。
新しいネットワークの定義を示す。
なお、Jupyter Notebook のノートの名前は keras_mnist_cnn10 だ。

# My Mnist CNN (Convolution layerの特徴マップは5個)
# Conv2D - ReLU - MaxPooling - Dence - ReLU - Dence
# 2018/05/25 by marsee
# Keras / Tensorflowで始めるディープラーニング入門 https://qiita.com/yampy/items/706d44417c433e68db0d
# のPythonコードを再利用させて頂いている

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D, Activation
from keras import backend as K

batch_size = 128
num_classes = 10
epochs = 12

img_rows, img_cols = 28, 28

(x_train, y_train), (x_test, y_test) = mnist.load_data()

#Kerasのバックエンドで動くTensorFlowとTheanoでは入力チャンネルの順番が違うので場合分けして書いています
if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

y_train = y_train.astype('int32')
y_test = y_test.astype('int32')
y_train = keras.utils.np_utils.to_categorical(y_train, num_classes)
y_test =  keras.utils.np_utils.to_categorical(y_test, num_classes)

model = Sequential()
model.add(Conv2D(10, kernel_size=(5, 5),
                 input_shape=input_shape))
model.add(Activation(activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(100))
model.add(Activation(activation='relu'))
model.add(Dense(num_classes))
model.add(Activation(activation='softmax'))

model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(),
              metrics=['accuracy'])
history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,
          verbose=1, validation_data=(x_test, y_test))


上のPython コードを実行した結果を示す。

('x_train shape:', (60000, 28, 28, 1))
(60000, 'train samples')
(10000, 'test samples')
Train on 60000 samples, validate on 10000 samples
Epoch 1/12
60000/60000 [==============================] - 12s 201us/step - loss: 0.2579 - acc: 0.9231 - val_loss: 0.0840 - val_acc: 0.9733
Epoch 2/12
60000/60000 [==============================] - 12s 201us/step - loss: 0.0785 - acc: 0.9762 - val_loss: 0.0564 - val_acc: 0.9819
Epoch 3/12
60000/60000 [==============================] - 12s 192us/step - loss: 0.0545 - acc: 0.9834 - val_loss: 0.0492 - val_acc: 0.9838
Epoch 4/12
60000/60000 [==============================] - 13s 210us/step - loss: 0.0425 - acc: 0.9869 - val_loss: 0.0442 - val_acc: 0.9862
Epoch 5/12
60000/60000 [==============================] - 12s 196us/step - loss: 0.0340 - acc: 0.9898 - val_loss: 0.0396 - val_acc: 0.9875
Epoch 6/12
60000/60000 [==============================] - 12s 198us/step - loss: 0.0284 - acc: 0.9915 - val_loss: 0.0382 - val_acc: 0.9874
Epoch 7/12
60000/60000 [==============================] - 11s 191us/step - loss: 0.0243 - acc: 0.9928 - val_loss: 0.0340 - val_acc: 0.9886
Epoch 8/12
60000/60000 [==============================] - 11s 189us/step - loss: 0.0206 - acc: 0.9937 - val_loss: 0.0371 - val_acc: 0.9878
Epoch 9/12
60000/60000 [==============================] - 12s 199us/step - loss: 0.0167 - acc: 0.9949 - val_loss: 0.0312 - val_acc: 0.9897
Epoch 10/12
60000/60000 [==============================] - 12s 195us/step - loss: 0.0146 - acc: 0.9954 - val_loss: 0.0317 - val_acc: 0.9896
Epoch 11/12
60000/60000 [==============================] - 11s 188us/step - loss: 0.0121 - acc: 0.9963 - val_loss: 0.0344 - val_acc: 0.9892
Epoch 12/12
60000/60000 [==============================] - 12s 205us/step - loss: 0.0103 - acc: 0.9970 - val_loss: 0.0320 - val_acc: 0.9898


精度と損失のグラフを示す。
tensorflow_keras_43_180528.png

model.summary() の結果を示す。
tensorflow_keras_44_180528.png

畳み込み層の conv2d_4 の統計情報、つまり、重みの最大値、最小値、バイアスの最大値、最小値、畳み込み層の最大値、最小値を見ていこう。
次回に続く。。。
  1. 2018年05月28日 04:56 |
  2. TensorFlow, Keras
  3. | トラックバック:0
  4. | コメント:0

TensorFlow + Kerasを使ってみた5(モデルの可視化、サーマリ)

TensorFlow + Kerasを使ってみた4(modelの重みの表示)”の続き。

前回はモデルの重みの表示をしてみた。今回は、モデルを可視化したり、サーマリを表示してみた。
なお、使用するのは MNIST の手書き数字を認識するCNN で畳み込み層の特徴マップの数は 10 個となっている。

今回、参考にさせて頂いたのは、”[keras]中間層の出力”だ。

最初にモデルの可視化ということで、モデルをPNG ファイルに出力する。

from keras.utils.vis_utils import plot_model
plot_model(model, show_shapes=True, to_file='./model.png')


このPython コードを実行すると pydot が無いというエラーになった。
tensorflow_keras_32_180525.png

sudo pip install pydot
でインストールした。
tensorflow_keras_33_180525.png

もう一度、Python コードを実行すると、今度は、GraphViz が無いというエラーだった。
tensorflow_keras_34_180526.png

sudo apt install graphviz
でインストールした。
tensorflow_keras_35_180526.png

3回めの実行で、 ~/DNN/Keras に model.png が出来た。
tensorflow_keras_38_180526.png

model.png を開いてみた。
tensorflow_keras_39_180526.png

次に、モデルのサーマリを表示した。

model.summary()


表示されたサーマリを示す。

_______________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_5 (Conv2D)            (None, 24, 24, 10)        260       
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 12, 12, 10)        0         
_________________________________________________________________
flatten_4 (Flatten)          (None, 1440)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 100)               144100    
_________________________________________________________________
dense_8 (Dense)              (None, 10)                1010      
=================================================================
Total params: 145,370
Trainable params: 145,370
Non-trainable params: 0


最後にスクリーンショットを貼っておく。
tensorflow_keras_42_180527.png
  1. 2018年05月27日 08:37 |
  2. TensorFlow, Keras
  3. | トラックバック:0
  4. | コメント:0

TensorFlow + Kerasを使ってみた4(modelの重みの表示)

TensorFlow + Kerasを使ってみた3(以前使用したCNNを使った学習)”の続き。

今回参照させて頂いたのは以下の記事だ。
Keras Documentation FAQ
KerasでCNNを簡単に構築
[keras]中間層の出力

とりあえずの私の CNN の実装としては層ごとの出力の値域、重みの変域、量子化された推論が必要なので、それらを TensorFlow + Keras でできるのかを探っていきたい。

使用しているのは、”TensorFlow + Kerasを使ってみた3(以前使用したCNNを使った学習)”のMNIST の畳み込み層 10 層の CNN だ。

まずは、モデル・パラメータの保存と読み込みをしたいと思う。それは、”Keras Documentation FAQ の Keras modelを保存するには?”に書いてあった。そのコードを示す。

# 学習済みモデルの保存

from keras.models import load_model

model.save('mnist_cnn10_model.h5') # creates a HDF5 file 'my_model.h5'


del model # deletes the existing model


# 学習済みモデルの読み込み

from keras.models import load_model

model = load_model('mnist_cnn10_model.h5')


~/DNN/Keras に mnist_cnn10_model.h5 ファイルができていた。
tensorflow_keras_37_180526.png

HDFView 2.9 で、mnist_cnn10_model.h5 ファイルを見てみたが、conv2d_5 を見てみたが 10 個カーネルがあるはずが 1 個しか見えないので、おかしい?
tensorflow_keras_40_180526.png

おかしいので、モデルの重みを表示してみた。

model_list = model.get_weights()
print model_list


表示された重みを示す。

[array([[[[-0.4802782 ,  0.08903456,  0.17129059, -0.18762073,
          -0.19215566,  0.0767612 , -0.19989973,  0.16069482,
           0.10121205,  0.10839224]],

        [[-0.55260026,  0.31618226,  0.02519642, -0.3371241 ,
          -0.18700214,  0.47713503,  0.01407363,  0.18270946,
          -0.00360232,  0.0480496 ]],

        [[-0.28096938,  0.31544265,  0.21408693, -0.4249214 ,
           0.05259206,  0.24198672, -0.12785499, -0.16038668,
           0.25517157,  0.02352966]],

        [[ 0.19760679,  0.17439696,  0.15509322, -0.3724036 ,
           0.14294085,  0.30212507,  0.20030482, -0.0445758 ,
           0.16999234,  0.11382752]],

        [[ 0.34932545,  0.19552206,  0.3587776 , -0.50719273,
           0.2470015 ,  0.07441435,  0.22428715, -0.35767055,
          -0.07781951,  0.22610919]]],


       [[[-0.50595784,  0.20634648,  0.32507014,  0.05827161,
          -0.15143315, -0.48752806,  0.10217368,  0.00805497,
           0.0631953 , -0.02518882]],

        [[-0.18465312,  0.30423662,  0.39945245,  0.01863884,
           0.03493409, -0.5750042 ,  0.02879223,  0.03981001,
           0.0992638 , -0.04431673]],

        [[-0.01881977,  0.28165781,  0.27969372, -0.12105816,
           0.39619493, -0.20682125,  0.11560314,  0.15059802,
           0.31867403, -0.01151863]],

        [[ 0.34246606,  0.23152475,  0.11814857, -0.16201593,
           0.33553603, -0.08004267,  0.05377955,  0.05844887,
           0.3824376 ,  0.29480523]],

        [[ 0.22607948,  0.1431006 ,  0.12278882, -0.06739531,
          -0.10244448, -0.16496062,  0.07855147, -0.0039023 ,
          -0.01187495,  0.31914127]]],


       [[[-0.36106065, -0.1614712 ,  0.0425666 ,  0.26438203,
           0.1530066 , -0.27296525,  0.13756014,  0.12951061,
          -0.13317643, -0.27455592]],

        [[-0.24096149,  0.00605726,  0.1439901 ,  0.2999792 ,
           0.3998839 , -0.57802194,  0.20733553,  0.22792007,
          -0.11590697, -0.10110019]],

        [[ 0.19279055, -0.05433302,  0.1588538 ,  0.36550105,
           0.3963272 , -0.40570349,  0.07748813, -0.03811004,
           0.27451742, -0.20028938]],

        [[ 0.3532763 , -0.00738094,  0.07176854,  0.44040167,
           0.05400598, -0.34983876,  0.0413314 ,  0.1358357 ,
           0.4606149 ,  0.22590274]],

        [[ 0.02443966,  0.31850007,  0.0679253 ,  0.44180956,
          -0.33495304, -0.47900125,  0.01341348,  0.11268906,
           0.37477848,  0.21716192]]],


       [[[-0.28646868, -0.33610198, -0.4617771 ,  0.21716239,
           0.13870867,  0.2500289 , -0.00521774,  0.18825355,
          -0.44079527, -0.45571223]],

        [[ 0.10224386, -0.38061497, -0.319509  ,  0.12749907,
           0.24174243,  0.12979732, -0.03148215,  0.25274608,
          -0.31800672, -0.29355517]],

        [[ 0.30091146, -0.28666645, -0.10368297,  0.39559895,
          -0.08464333, -0.20640281,  0.10996511, -0.12132592,
          -0.2605404 , -0.15574814]],

        [[ 0.27224183, -0.08070813, -0.30462918,  0.2425765 ,
          -0.42863956, -0.38760048,  0.01434944,  0.01161201,
           0.11027226,  0.15539089]],

        [[-0.03363509,  0.15163417, -0.39445326,  0.26487276,
          -0.46634   , -0.37421483, -0.25277063, -0.30671078,
           0.15761915,  0.27163678]]],


       [[[ 0.01090175, -0.3971682 , -0.3616919 , -0.2198588 ,
           0.22649288,  0.36216414,  0.16255492,  0.28195596,
          -0.2900382 , -0.3076533 ]],

        [[ 0.02631478, -0.21446598, -0.64792824, -0.12109952,
           0.0477443 ,  0.6136793 ,  0.07133511,  0.03860151,
          -0.32767197, -0.33002102]],

        [[ 0.26029003, -0.17678206, -0.53189945, -0.22726585,
          -0.29128915,  0.50199383,  0.10108512, -0.10461918,
          -0.41573155, -0.39046872]],

        [[ 0.08828537, -0.34382978, -0.27508992, -0.20536248,
          -0.29675165,  0.05918439, -0.04499418, -0.4220725 ,
          -0.12406551,  0.18522236]],

        [[-0.16209605, -0.10475101, -0.17508513, -0.19714086,
          -0.2464418 , -0.06318696, -0.16879626, -0.35019565,
           0.11980721,  0.3179446 ]]]], dtype=float32), array([-0.05326715, -0.10172927, -0.00153971, -0.00202374, -0.11225405,
        0.03784639, -0.22747649, -0.1184488 , -0.13200918, -0.18062808],
      dtype=float32), array([[ 0.05988384,  0.0177884 , -0.01428634, ...,  0.03967512,
        -0.02745369,  0.0235893 ],
       [-0.02437956, -0.02609818,  0.00166744, ..., -0.0283565 ,
        -0.03127908, -0.03267145],
       [ 0.01952924,  0.04912804,  0.05458435, ..., -0.03700528,
         0.02735562,  0.05371138],
       ...,
       [ 0.03930264, -0.05253042, -0.02040245, ...,  0.02623902,
        -0.04602768, -0.0570806 ],
       [ 0.05283869,  0.03285475,  0.0225143 , ...,  0.0078735 ,
         0.06278732, -0.02751559],
       [-0.03585282,  0.05243319, -0.00109443, ..., -0.02352152,
        -0.0505695 , -0.03361446]], dtype=float32), array([-0.005398  , -0.01852978,  0.022526  , -0.05067606, -0.01559338,
        0.00181101, -0.00674882,  0.02209998,  0.0119611 ,  0.02653758,
        0.04715605,  0.00115289,  0.01218627,  0.00955511, -0.00915333,
        0.01971679,  0.00280326, -0.00939319,  0.02062515, -0.03470616,
       -0.01717205, -0.0006302 ,  0.0281504 , -0.02706482, -0.02169199,
       -0.04854991, -0.01171666, -0.03342789, -0.00403417,  0.04069184,
       -0.00713639,  0.01422117, -0.03270031,  0.07872368, -0.01933507,
       -0.03112246, -0.0127308 ,  0.02346958,  0.02234364,  0.0009662 ,
       -0.01526533,  0.02731066, -0.02599644,  0.03074143, -0.02295697,
        0.01084804, -0.00369489, -0.02861736, -0.01144757,  0.01723441,
        0.03169027, -0.04880989, -0.00714325,  0.00121178,  0.02717792,
        0.02516818,  0.04525929,  0.01094827, -0.0134111 , -0.00777571,
       -0.01005205, -0.03453677,  0.02772082, -0.03537939, -0.01389873,
       -0.00258296,  0.0014193 ,  0.02491944,  0.00557612,  0.05277465,
        0.03517715,  0.03964692,  0.04351197, -0.00425854, -0.00757222,
       -0.00217356,  0.01388615,  0.04111845, -0.01360166, -0.02312726,
        0.01533243, -0.01296438, -0.0164743 , -0.01136677,  0.01515818,
        0.02614263, -0.0167461 ,  0.00552223, -0.01624235,  0.00473867,
        0.03577977, -0.01952587,  0.01290206,  0.0558174 ,  0.01492252,
       -0.04054749, -0.01954609, -0.00233423,  0.01031396,  0.03316768],
      dtype=float32), array([[-2.80878514e-01, -2.88952619e-01, -6.83253184e-02,
        -4.65221584e-01, -7.49759227e-02,  6.33983985e-02,
        -4.09977026e-02,  1.52914107e-01, -1.62776962e-01,
         2.25325406e-01],
       [-1.16497621e-01, -4.10411984e-01, -7.08504915e-02,
         1.14576600e-01, -2.84170300e-01,  2.02368587e-01,
         3.22541237e-01,  2.77578473e-01, -2.31056251e-02,
        -3.96834254e-01],
       [ 1.92328736e-01, -2.25020871e-01, -1.11906387e-01,
        -2.87643850e-01, -1.05152600e-01,  5.42706251e-02,
         4.07181308e-03, -4.37325239e-01, -5.72222099e-02,
        -3.27897817e-01],
       [ 4.58031707e-02, -4.31063652e-01, -1.70698479e-01,
         2.36170590e-01,  3.97744030e-02,  2.09598631e-01,
         1.29017875e-01, -2.64844820e-02,  1.15363739e-01,
         6.25539050e-02],
       [-2.91498125e-01,  7.66226426e-02,  1.30788106e-02,
         1.21028148e-01,  1.15245983e-01,  1.78283691e-01,
         3.21665317e-01,  1.02033220e-01,  1.45019650e-01,
        -4.34377521e-01],
       [ 1.68283820e-01, -3.09178401e-02,  6.57193065e-02,
         1.83798093e-02, -7.25373849e-02,  1.67001039e-02,
         1.26208022e-01, -3.93356174e-01,  6.57608733e-02,
         2.00210825e-01],
       [ 9.03305188e-02, -3.31516951e-01,  1.92073271e-01,
         1.28750160e-01, -5.60461223e-01,  2.02110186e-01,
         2.47376963e-01,  1.91605061e-01, -1.78822219e-01,
        -2.13595301e-01],
       [ 1.79702878e-01, -2.85403341e-01, -3.79967839e-01,
        -3.18726778e-01, -3.74518842e-01, -2.43183710e-02,
         2.44848490e-01,  2.43323982e-01, -1.34915501e-01,
        -3.29058856e-01],
       [-4.59100045e-02,  1.11619122e-01,  2.06353113e-01,
        -3.57645638e-02, -2.68476754e-01,  1.54805463e-02,
         1.19950175e-01,  2.68370628e-01,  1.79247513e-01,
        -5.92741966e-01],
       [-1.18884221e-02,  3.30010265e-01, -3.34912717e-01,
        -1.25106558e-01, -3.13555360e-01,  3.38253379e-01,
         1.71396181e-01, -2.64308155e-01, -2.28820279e-01,
         9.15235952e-02],
       [-1.96573287e-02,  1.09934568e-01,  3.11464965e-01,
        -3.46049458e-01, -4.22263414e-01,  1.13516875e-01,
        -1.10042468e-01,  2.11521327e-01,  1.19895123e-01,
         1.22512609e-01],
       [ 4.91617508e-02, -1.46876182e-02, -1.10340687e-02,
        -1.40922934e-01,  2.40739062e-01,  2.78086541e-03,
        -3.04801941e-01,  1.87239528e-01, -4.08947110e-01,
         1.64359197e-01],
       [-1.23397402e-01,  1.75705373e-01, -9.41747651e-02,
        -3.65135044e-01,  1.11648791e-01,  1.61744043e-01,
        -4.55581516e-01,  2.70987302e-01, -2.47416683e-02,
        -8.45107734e-02],
       [ 8.28784853e-02,  1.96988985e-01, -4.19137686e-01,
         3.10681686e-02, -2.06089336e-02,  2.11591825e-01,
        -3.51008564e-01,  1.58253074e-01, -8.44351202e-02,
         5.98762669e-02],
       [-2.11891711e-01,  7.34165385e-02,  6.44661207e-03,
         8.63067247e-03, -2.31139749e-01, -1.74408883e-01,
         2.16911077e-01, -1.50186718e-01, -3.73014003e-01,
        -3.68643641e-01],
       [ 2.34784991e-01,  3.78104374e-02,  2.63703734e-01,
        -1.64538354e-01,  7.57270120e-03, -3.90653104e-01,
        -1.43432751e-01,  2.12894857e-01,  1.69312999e-01,
        -2.59785682e-01],
       [ 9.35157537e-02, -2.01519594e-01,  1.75988719e-01,
        -2.92130262e-01,  2.28341654e-01,  1.91730857e-02,
         7.07815886e-02,  1.37614071e-01,  1.98204890e-01,
        -1.40204385e-01],
       [-2.83252615e-02,  3.78034413e-02, -1.35061309e-01,
        -1.62101313e-01, -6.32801875e-02,  2.15638056e-01,
        -2.08552018e-01, -9.57552046e-02, -3.08166534e-01,
        -1.19683079e-01],
       [-1.83200374e-01,  3.76025349e-01, -1.16502978e-01,
        -4.83475059e-01, -6.25678301e-02,  2.37899646e-01,
        -3.66046041e-01,  2.22620890e-01, -4.73464787e-01,
        -1.89155549e-01],
       [-8.66091922e-02, -9.89375189e-02, -1.94817007e-01,
         9.37115997e-02,  1.39805645e-01, -1.35264009e-01,
        -3.90744150e-01,  2.31051669e-01,  1.50245175e-01,
         1.29960239e-01],
       [-1.88448876e-01, -2.16989845e-01, -1.20259278e-01,
        -2.40481123e-01,  2.78745115e-01, -1.64214984e-01,
        -1.34227902e-01, -2.03885585e-01, -2.50468135e-01,
         1.82746530e-01],
       [-3.06171596e-01,  2.18999043e-01, -3.19759518e-01,
        -1.01892829e-01, -1.74568921e-01, -2.91039079e-01,
         2.94285625e-01, -2.74888486e-01, -6.43369108e-02,
         3.48723471e-01],
       [ 2.33909730e-02, -6.47128597e-02, -4.31493759e-01,
         2.13833421e-01, -1.19319089e-01, -2.89464384e-01,
        -1.12514161e-01,  3.63984853e-01, -4.55533743e-01,
         2.60717385e-02],
       [-2.95785904e-01, -2.07676753e-01,  6.98602349e-02,
        -8.48661549e-03, -3.83481503e-01, -6.69387057e-02,
         2.18238056e-01,  2.49815926e-01, -2.70438492e-01,
        -4.34468538e-01],
       [ 4.70142439e-02, -6.06754273e-02, -5.58634579e-01,
         1.58430338e-01,  3.16151381e-02, -1.07376061e-01,
        -3.49974513e-01, -7.29726180e-02, -3.56006436e-02,
         3.03892493e-01],
       [-3.34823012e-01, -2.51858741e-01, -2.25251958e-01,
         2.37957463e-01,  2.57917583e-01,  3.60314548e-02,
        -3.53990138e-01, -7.66057670e-02,  2.54444063e-01,
         2.12081417e-01],
       [-8.70895386e-02,  5.11322869e-03, -2.91756868e-01,
        -2.07769036e-01,  2.97158152e-01, -4.08434540e-01,
        -2.68067390e-01,  2.00924113e-01, -4.38634753e-01,
         7.55239949e-02],
       [ 1.79762044e-03, -3.45835894e-01,  4.62487787e-02,
         1.29018426e-01,  2.00612947e-01,  2.40591943e-01,
         5.16790375e-02, -4.29002047e-01,  2.19572246e-01,
         1.05557097e-02],
       [-3.16552520e-01,  8.77547190e-02,  4.17853892e-02,
        -4.08354223e-01, -1.12028815e-01,  2.61908531e-01,
         7.60191679e-02, -2.38478839e-01,  1.47306379e-02,
         3.49691689e-01],
       [ 2.84786552e-01, -6.80592703e-03, -4.72505003e-01,
        -2.76086569e-01, -3.66689682e-01,  1.20811805e-01,
        -1.53523177e-01, -1.52481034e-01, -2.68145621e-01,
         4.70004231e-02],
       [-4.45838310e-02,  7.93924257e-02, -7.34841526e-02,
         2.08099082e-01, -3.68329078e-01,  2.06218287e-01,
        -4.25800115e-01,  1.48879498e-01, -1.16511500e-02,
         2.36930773e-01],
       [ 1.89541459e-01, -1.49233416e-01,  2.09065124e-01,
         2.00762693e-02, -2.88338929e-01,  1.99728936e-01,
        -1.97521448e-01,  1.43347621e-01, -2.36340642e-01,
        -3.29071403e-01],
       [-1.43338472e-01,  2.85153925e-01,  3.88877541e-02,
         2.14123070e-01, -3.78108233e-01,  2.49663651e-01,
        -1.90097138e-01, -2.09107697e-01, -3.31528336e-01,
         2.41039231e-01],
       [-9.46740285e-02,  3.46576244e-01, -3.03534240e-01,
        -5.37712157e-01, -3.05850301e-02, -6.98754862e-02,
         1.39591247e-01, -3.61753345e-01, -2.54348367e-01,
        -1.73599243e-01],
       [-7.52603561e-02, -4.64169860e-01, -3.11523080e-01,
        -3.30096390e-03, -2.48671651e-01,  2.67981470e-01,
         3.46507430e-02,  6.33662045e-02, -1.73883047e-02,
         1.57723546e-01],
       [-2.55069643e-01, -3.28791529e-01, -1.91593617e-01,
         1.98152009e-02,  2.38621473e-01,  1.73862070e-01,
        -3.37664545e-01,  8.53433087e-02, -2.10560709e-01,
         3.63261439e-02],
       [-5.02947047e-02, -2.81642228e-01, -2.31026471e-01,
        -2.10499510e-01, -1.52135611e-01, -2.49384075e-01,
        -1.82276487e-01,  3.45350131e-02,  1.10356145e-01,
         6.95606992e-02],
       [-9.24093127e-02,  4.59271222e-02,  3.93759608e-02,
         2.91396320e-01, -6.77868873e-02, -4.35065717e-01,
        -3.41314614e-01, -2.61516005e-01, -1.26054004e-01,
         2.79389948e-01],
       [-1.65248692e-01,  3.01901400e-01, -4.27287251e-01,
        -2.14315772e-01,  1.85965419e-01, -6.04463667e-02,
         6.02773912e-02,  2.65717179e-01, -4.41522062e-01,
         1.70508966e-01],
       [-2.63828665e-01, -8.64971578e-02, -2.78461576e-01,
         6.19269870e-02,  1.75784558e-01,  2.26351202e-01,
         2.23095179e-01, -7.56275430e-02,  1.32076845e-01,
        -9.38763022e-02],
       [-2.86855221e-01,  3.81994694e-02, -3.11091870e-01,
        -7.58372173e-02,  2.34059244e-01, -1.56661533e-02,
         7.35731423e-02,  1.77167624e-01,  2.17671648e-01,
        -1.91961788e-03],
       [ 1.86832726e-01,  1.35448843e-01,  3.00116807e-01,
        -1.73696235e-01, -9.37856138e-02,  1.51225656e-01,
        -9.32918712e-02, -3.50246668e-01, -3.24846774e-01,
        -5.62554505e-03],
       [-3.51985544e-01,  1.26511768e-01,  1.49354547e-01,
         1.93097010e-01,  8.64192247e-02,  7.05208927e-02,
        -5.19544959e-01,  2.70478725e-01,  1.65952802e-01,
         1.64275318e-02],
       [-1.92465544e-01,  4.02155459e-01, -1.52402923e-01,
        -1.99932814e-01,  3.47709596e-01,  6.46834169e-03,
         3.87867332e-01,  1.28063947e-01, -3.45817953e-01,
        -4.91191477e-01],
       [ 6.11011274e-02, -2.56536752e-01, -3.15822661e-02,
        -1.52900815e-01, -3.21199924e-01,  1.59323826e-01,
         3.27706814e-01,  1.36229798e-01, -2.06773639e-01,
         2.47566059e-01],
       [-2.85091579e-01, -1.90595418e-01,  2.74384946e-01,
        -1.67294994e-01, -1.41146630e-01,  2.19166890e-01,
        -7.70386904e-02, -6.47071823e-02,  9.02576721e-06,
        -3.55082661e-01],
       [-8.33381787e-02, -2.64859885e-01,  1.22059382e-01,
        -2.96119362e-01, -2.02185154e-01, -2.04638451e-01,
        -2.34781384e-01,  4.05640006e-01,  3.95843871e-02,
        -5.00460327e-01],
       [-6.75931275e-02, -1.52381118e-02, -5.84617443e-02,
        -7.04567367e-03, -4.38545316e-01, -1.61278471e-01,
        -8.04644227e-02,  2.87377626e-01, -3.14729095e-01,
         1.63273811e-01],
       [-4.85402755e-02, -1.85809851e-01,  2.34951764e-01,
         1.07792743e-01, -1.19594529e-01, -4.99206930e-01,
        -3.42985153e-01,  2.04150021e-01, -1.40759617e-01,
         4.17127758e-02],
       [ 4.76027206e-02,  2.02197984e-01, -2.56469403e-03,
         5.16320094e-02, -4.73078012e-01,  1.83827221e-01,
         2.54296690e-01, -5.82304932e-02,  2.67515868e-01,
        -3.10604237e-02],
       [ 2.36623093e-01, -2.22096846e-01, -1.41983896e-01,
        -3.52556050e-01, -3.20000276e-02, -9.11143273e-02,
         5.81668355e-02,  1.78689566e-02,  2.08153933e-01,
        -3.37888002e-01],
       [-4.14484113e-01,  5.72569035e-02,  6.03880510e-02,
         2.75157154e-01, -6.30981252e-02, -2.10512072e-01,
         1.08315110e-01,  2.03937560e-01, -6.50893524e-02,
        -2.20293328e-01],
       [ 2.76402891e-01,  1.74387738e-01,  1.06817611e-01,
        -8.05123001e-02, -1.91160321e-01, -3.90674204e-01,
        -2.64012404e-02,  2.29674041e-01,  2.47093081e-01,
         1.52173817e-01],
       [ 1.88219309e-01, -1.60041615e-01, -2.60463208e-01,
        -2.78181106e-01,  1.74190581e-01,  1.50532395e-01,
         2.39792839e-01, -1.97646185e-03, -3.39136064e-01,
         2.27342233e-01],
       [ 1.81723490e-01,  2.82607019e-01,  1.06002472e-01,
        -8.10205936e-03, -3.41993093e-01, -1.69809997e-01,
         7.41727501e-02, -3.31415057e-01, -1.56984672e-01,
         6.68503791e-02],
       [-3.35356623e-01,  2.68279344e-01, -6.96995184e-02,
        -5.31341374e-01,  1.01676062e-01,  1.10403292e-01,
        -1.39216095e-01,  1.23701543e-01, -2.04578757e-01,
        -3.05327803e-01],
       [ 1.52495489e-01,  2.40899324e-01,  8.64105150e-02,
        -4.70167339e-01,  3.36731046e-01, -2.88031578e-01,
         1.22474372e-01, -2.08129853e-01, -4.26135898e-01,
         1.23008616e-01],
       [ 2.19447419e-01, -4.25975442e-01, -2.69947082e-01,
        -2.57135965e-02, -3.49941641e-01,  1.04076453e-01,
        -3.61683359e-03, -2.87522405e-01,  2.05251873e-02,
        -1.03612281e-01],
       [ 1.96892500e-01, -8.73491615e-02,  1.34108305e-01,
         1.87838942e-01, -7.36293867e-02, -1.67612627e-01,
         5.34972176e-02, -1.28114730e-01,  2.39227191e-01,
         1.52084604e-01],
       [-2.42285430e-01, -1.80974156e-01,  2.19037905e-01,
        -2.78696179e-01, -1.71007901e-01, -2.07579866e-01,
        -2.65557259e-01,  2.27460340e-01, -1.89204544e-01,
         1.05596513e-01],
       [-3.00730944e-01, -8.07805061e-02, -3.20149034e-01,
         1.02798693e-01,  3.85463059e-01,  1.40067115e-01,
        -3.96318287e-01,  1.27322385e-02, -5.26268661e-01,
         2.65310705e-01],
       [ 1.77418590e-02, -3.32880855e-01,  1.96217701e-01,
         9.38635245e-02,  9.29819867e-02, -3.18250328e-01,
        -9.68264043e-02,  1.16937332e-01,  7.45990947e-02,
         1.64286569e-01],
       [-2.65594870e-01,  1.00067541e-01,  1.91965532e-02,
        -3.47705372e-03,  1.33793931e-02,  1.58512205e-01,
        -1.75091654e-01, -8.55431631e-02,  1.84122652e-01,
         2.84497470e-01],
       [-2.72124559e-01, -1.85961753e-01, -4.86050874e-01,
         1.09921977e-01, -1.12289395e-02,  2.26568803e-01,
         2.64995009e-01,  1.38251394e-01, -4.94774207e-02,
        -6.99937791e-02],
       [ 3.74109745e-02,  2.47385919e-01,  9.68004614e-02,
         2.26332277e-01,  1.86111126e-02, -2.54699349e-01,
         1.74074680e-01,  1.21922724e-01, -3.51350099e-01,
        -4.60877508e-01],
       [ 2.35768378e-01,  4.54334393e-02,  5.33979014e-02,
         1.27460673e-01, -7.01661766e-01, -8.58926475e-02,
        -2.22885370e-01, -4.78002690e-02, -2.42045864e-01,
        -4.47942078e-01],
       [-3.86173278e-01,  2.65619904e-01,  2.26018026e-01,
         8.76465663e-02,  2.24053368e-01, -3.48186679e-02,
        -2.83748150e-01,  5.31676486e-02, -3.37433070e-01,
        -3.55700403e-01],
       [-3.32953036e-01,  1.71580344e-01,  1.71892568e-01,
         2.22348973e-01, -1.97261035e-01,  7.07722157e-02,
        -2.85758406e-01, -1.79066554e-01,  1.86767325e-01,
        -2.61755675e-01],
       [-1.66537970e-01, -7.62976781e-02,  1.34396255e-01,
         2.12341934e-01, -3.06788445e-01, -2.03811824e-01,
         1.15235768e-01, -2.25298047e-01,  8.40689316e-02,
        -1.22556776e-01],
       [-1.55728787e-01,  8.89719203e-02,  2.23977998e-01,
        -3.28815132e-02, -4.55562890e-01, -4.49441552e-01,
         1.48577452e-01, -2.08408982e-01, -3.42789024e-01,
         7.36563057e-02],
       [-2.99370289e-01,  2.14110270e-01, -2.54100394e-02,
         5.72636165e-03,  2.37848416e-01, -1.61451086e-01,
        -2.57838815e-01, -1.68274287e-02,  2.10511789e-01,
        -1.72612481e-02],
       [ 2.54113376e-01,  1.52949333e-01, -3.42802823e-01,
         8.62261131e-02,  3.54665339e-01, -2.70900726e-01,
        -2.58765638e-01, -4.86204416e-01, -6.25081882e-02,
        -2.48219520e-01],
       [-2.11227015e-02,  3.11614543e-01, -7.31830969e-02,
         5.06604388e-02, -3.96664739e-01, -5.11043549e-01,
        -7.12758601e-02,  5.23637906e-02, -1.72585770e-01,
        -1.12323724e-01],
       [-2.88673878e-01, -3.41729730e-01,  8.76247585e-02,
        -3.31366479e-01,  2.95934714e-02, -1.40673220e-01,
        -2.64001906e-01, -3.75759751e-01, -2.20368892e-01,
         1.65278658e-01],
       [-2.61867434e-01, -1.25094682e-01, -1.88089572e-02,
        -2.86843389e-01,  2.07129523e-01,  9.33256671e-02,
         2.02610716e-01,  1.14842623e-01,  2.93302596e-01,
        -3.47602844e-01],
       [-5.24003319e-02, -1.82070181e-01,  2.64166474e-01,
         1.53725669e-01,  1.60775915e-01, -3.94912884e-02,
         5.98972812e-02,  1.27977923e-01,  3.62922251e-02,
        -2.37544343e-01],
       [ 2.83543635e-02, -3.82527709e-01,  1.03175066e-01,
        -3.20723563e-01, -3.55718613e-01,  3.48593853e-02,
         1.52670860e-01,  2.37062365e-01, -8.84934217e-02,
         1.07653186e-01],
       [-8.02218262e-03,  1.98889375e-01,  2.36755818e-01,
        -3.64727765e-01, -1.66197829e-02, -2.52599061e-01,
         3.08327228e-02, -2.65515476e-01, -4.67243642e-02,
        -9.93776992e-02],
       [-2.88613439e-01,  1.60547391e-01,  2.01603085e-01,
         3.29991728e-01, -5.04409075e-01,  2.12129638e-01,
         1.29706696e-01,  3.26536223e-02, -5.79475239e-02,
        -3.03462207e-01],
       [ 1.73510164e-01, -4.73308414e-01,  3.45571339e-02,
        -2.80015141e-01, -1.42622843e-01, -3.91261354e-02,
        -4.54723686e-02, -4.02801454e-01, -4.28644791e-02,
         2.38903821e-01],
       [-7.78434575e-02,  7.01967180e-02, -3.48449171e-01,
        -5.15536427e-01,  3.84873636e-02, -3.21855098e-01,
         3.03351879e-01, -1.55083299e-01,  2.30123237e-01,
        -8.70974036e-04],
       [-2.91565925e-01, -1.71608552e-01, -3.24349612e-01,
         2.24242955e-01,  2.14290485e-01, -3.38350609e-02,
         1.36947989e-01,  1.28424913e-02,  5.85890487e-02,
        -4.89389390e-01],
       [-3.49002093e-01, -2.24898815e-01,  2.38856182e-01,
         2.56595671e-01,  2.10459158e-02,  1.75605848e-01,
        -3.92496169e-01, -1.32802978e-01,  8.51763859e-02,
         4.41880003e-02],
       [ 1.49462268e-01, -8.83552507e-02, -2.92263716e-01,
         3.37864727e-01,  1.99508980e-01, -1.99407041e-02,
         1.95664078e-01, -3.62197131e-01, -2.76591420e-01,
        -1.89328298e-01],
       [-1.29553089e-02,  2.76988298e-01, -4.49188322e-01,
         4.61517796e-02, -4.40331697e-01,  3.83680701e-01,
         1.05097756e-01, -5.92100248e-02, -3.65615457e-01,
        -1.52612358e-01],
       [ 1.82668746e-01,  2.31690034e-01, -2.92528003e-01,
        -2.55986303e-01,  8.32752287e-02,  2.51465917e-01,
        -1.01054765e-01, -7.06190243e-03, -2.96387017e-01,
        -4.84422773e-01],
       [-1.21400669e-01, -3.13585289e-02, -1.49048269e-01,
        -2.05441698e-01,  2.27967411e-01,  1.49613246e-01,
        -1.65963650e-01, -3.29703897e-01,  8.17758366e-02,
         9.74359736e-02],
       [ 2.89199293e-01, -7.04863593e-02, -1.22465491e-01,
        -7.15844659e-03, -5.14583707e-01,  6.51249960e-02,
        -1.20044194e-01,  3.26011032e-01, -1.46292567e-01,
        -1.34193763e-01],
       [ 2.07244024e-01,  2.18047187e-01, -3.28369349e-01,
         3.54067504e-01, -1.34166539e-01, -3.58019650e-01,
         3.81010287e-02,  1.76515803e-01, -2.42457300e-01,
        -2.58901734e-02],
       [-1.83583036e-01,  9.63129923e-02,  1.80392161e-01,
        -3.12118948e-01,  1.35678813e-01, -4.84556109e-01,
         2.94969007e-02,  2.10692644e-01,  2.85272449e-01,
         7.40195960e-02],
       [ 2.14306619e-02,  2.23890185e-01, -5.03564402e-02,
        -5.09863198e-01,  7.22838640e-02,  3.17906708e-01,
        -2.77185917e-01,  9.14458837e-03,  2.00836927e-01,
         2.31873438e-01],
       [ 4.97248918e-02,  4.68859673e-02, -3.31564605e-01,
         2.14749053e-01, -2.34072153e-02,  2.28042364e-01,
        -2.55816698e-01,  1.08321579e-02,  2.10366532e-01,
         1.82601720e-01],
       [-2.62626056e-02, -1.21024236e-01,  1.02919996e-01,
        -5.97031154e-02,  1.77486595e-02,  6.08048066e-02,
         2.15883568e-01, -5.29421747e-01, -1.41826728e-02,
         6.10244796e-02],
       [ 2.30102479e-01,  1.66655496e-01, -8.18227942e-04,
        -3.35090339e-01,  1.70878336e-01, -1.07719235e-01,
        -2.74331063e-01,  1.52734920e-01, -4.22130316e-01,
         4.29905392e-03],
       [ 1.01837583e-01, -1.64990351e-01,  4.56752926e-02,
        -3.95802766e-01,  1.88975126e-01, -7.27127343e-02,
         1.77403137e-01, -4.33928579e-01,  6.07757755e-02,
        -8.86140168e-02],
       [-4.19868499e-01, -2.48233587e-01, -2.57547587e-01,
         3.02368283e-01,  1.13392621e-01, -1.54164657e-01,
        -2.27882534e-01,  1.99300513e-01, -4.92941290e-01,
         3.32245409e-01],
       [-1.01652101e-01, -4.19613987e-01,  3.48329127e-01,
        -2.74736881e-01,  2.61124879e-01,  6.26765192e-02,
         4.22035269e-02, -5.20829000e-02, -3.64311099e-01,
        -1.79357663e-01],
       [ 1.39732897e-01, -2.66956538e-01, -3.84257436e-01,
        -1.13201685e-01,  7.89685026e-02,  2.69948304e-01,
         1.01909287e-01, -9.96095836e-02,  1.48862645e-01,
         1.53904662e-01],
       [ 1.36547923e-01, -2.41517991e-01,  2.04632416e-01,
        -1.64432637e-02, -1.54070064e-01, -1.83150306e-01,
        -2.03562394e-01, -2.21244648e-01,  1.00092985e-01,
        -1.01901151e-01],
       [-5.50452732e-02, -6.54754788e-02,  1.59336641e-01,
        -3.78038645e-01,  1.72017649e-01, -2.92336732e-01,
         3.41575712e-01, -4.80901748e-01, -3.72718275e-01,
        -1.70941189e-01]], dtype=float32), array([ 0.01098968,  0.03846952,  0.00500416, -0.03767109, -0.01391269,
       -0.02019179, -0.05111629, -0.06184113,  0.04913256, -0.00327604],
      dtype=float32)]

from keras.utils.vis_utils import plot_model
plot_model(model, show_shapes=True, to_file='./model.png')


Jupyter Notebookでの表示を示す。
tensorflow_keras_41_180526.png
  1. 2018年05月26日 07:52 |
  2. TensorFlow, Keras
  3. | トラックバック:0
  4. | コメント:0

TensorFlow + Kerasを使ってみた3(以前使用したCNNを使った学習)

TensorFlow + Kerasを使ってみた2(実践編)”の続き。

TensorFlow + Kerasを使ってみた2(実践編)”で使用した”Keras / Tensorflowで始めるディープラーニング入門”の Python コードをそのまま再利用させて頂いて、”「ゼロから作るDeep Learning」の畳み込みニューラルネットワークのハードウェア化5(再度学習)”の CNN を学習していこう。これは MNIST の手書き数字を認識するCNN で畳み込み層の特徴マップの数は 10 個となっている。

まずは、Python コードを示す。

# My Mnist CNN
# Conv2D - ReLU - MaxPooling - Dence - ReLU - Dence
# 2018/05/25 by marsee
# Keras / Tensorflowで始めるディープラーニング入門 https://qiita.com/yampy/items/706d44417c433e68db0d
# のPythonコードを再利用させて頂いている

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K

batch_size = 128
num_classes = 10
epochs = 12

img_rows, img_cols = 28, 28

(x_train, y_train), (x_test, y_test) = mnist.load_data()

#Kerasのバックエンドで動くTensorFlowとTheanoでは入力チャンネルの順番が違うので場合分けして書いています
if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

y_train = y_train.astype('int32')
y_test = y_test.astype('int32')
y_train = keras.utils.np_utils.to_categorical(y_train, num_classes)
y_test =  keras.utils.np_utils.to_categorical(y_test, num_classes)

model = Sequential()
model.add(Conv2D(10, kernel_size=(5, 5),
                 activation='relu',
                 input_shape=input_shape))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(100, activation='relu'))
model.add(Dense(num_classes, activation='softmax'))

model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(),
              metrics=['accuracy'])
history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,
          verbose=1, validation_data=(x_test, y_test))


結果を示す。

('x_train shape:', (60000, 28, 28, 1))
(60000, 'train samples')
(10000, 'test samples')
Train on 60000 samples, validate on 10000 samples
Epoch 1/12
60000/60000 [==============================] - 16s 265us/step - loss: 0.2680 - acc: 0.9204 - val_loss: 0.1202 - val_acc: 0.9640
Epoch 2/12
60000/60000 [==============================] - 16s 261us/step - loss: 0.0822 - acc: 0.9754 - val_loss: 0.0626 - val_acc: 0.9792
Epoch 3/12
60000/60000 [==============================] - 16s 260us/step - loss: 0.0558 - acc: 0.9830 - val_loss: 0.0476 - val_acc: 0.9845
Epoch 4/12
60000/60000 [==============================] - 15s 256us/step - loss: 0.0429 - acc: 0.9869 - val_loss: 0.0470 - val_acc: 0.9842
Epoch 5/12
60000/60000 [==============================] - 15s 254us/step - loss: 0.0349 - acc: 0.9891 - val_loss: 0.0369 - val_acc: 0.9867
Epoch 6/12
60000/60000 [==============================] - 17s 279us/step - loss: 0.0290 - acc: 0.9910 - val_loss: 0.0376 - val_acc: 0.9871
Epoch 7/12
60000/60000 [==============================] - 16s 274us/step - loss: 0.0238 - acc: 0.9927 - val_loss: 0.0372 - val_acc: 0.9877
Epoch 8/12
60000/60000 [==============================] - 16s 261us/step - loss: 0.0204 - acc: 0.9940 - val_loss: 0.0328 - val_acc: 0.9884
Epoch 9/12
60000/60000 [==============================] - 15s 251us/step - loss: 0.0172 - acc: 0.9947 - val_loss: 0.0334 - val_acc: 0.9878
Epoch 10/12
60000/60000 [==============================] - 15s 254us/step - loss: 0.0146 - acc: 0.9957 - val_loss: 0.0342 - val_acc: 0.9889
Epoch 11/12
60000/60000 [==============================] - 15s 254us/step - loss: 0.0123 - acc: 0.9966 - val_loss: 0.0359 - val_acc: 0.9878
Epoch 12/12
60000/60000 [==============================] - 15s 257us/step - loss: 0.0109 - acc: 0.9968 - val_loss: 0.0340 - val_acc: 0.9889


accuracy は 0.9889 だった。

次に、accuracy と loss のグラフを書いてみた。これも”Keras / Tensorflowで始めるディープラーニング入門”の Python コードをそのまま再利用させて頂いている。

# Keras / Tensorflowで始めるディープラーニング入門 https://qiita.com/yampy/items/706d44417c433e68db0d
# のPythonコードを再利用させて頂いている

%matplotlib inline
import pandas as pd
import matplotlib.pyplot as plt

plt.plot(history.history['acc'])
plt.plot(history.history['val_acc'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()

# plot the loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()


描画されたグラフを示す。
tensorflow_keras_31_180525.png

多少 train と test の間が離れているが、test の accuracy は振動はしているが、増加はしているので、まだ過学習ではないと思うがいかがだろうか?

最後に、Jupyter Notebook の画像を貼っておく。
tensorflow_keras_27_180525.png
tensorflow_keras_28_180525.png
tensorflow_keras_29_180525.png
tensorflow_keras_30_180525.png
  1. 2018年05月25日 04:42 |
  2. TensorFlow, Keras
  3. | トラックバック:0
  4. | コメント:0

TensorFlow + Kerasを使ってみた2(実践編)

TensorFlow + Kerasを使ってみた1(インストール編)”の続き。

前回は、TensorFlow + Keras と Jupyter Notebook をインストールして、「Keras / Tensorflowで始めるディープラーニング入門」を参考にしてやってみることにした。今回は、「Keras / Tensorflowで始めるディープラーニング入門」のサンプルをJupyter Notebook でやってみよう。

最初に、「手書き文字の認識(1): 全結合層のみ」の keras_mnist_pixeldata.ipynb をやってみよう。
tensorflow_keras_17_180523.png
tensorflow_keras_18_180523.png
tensorflow_keras_19_180523.png
tensorflow_keras_20_180523.png
tensorflow_keras_21_180523.png
tensorflow_keras_22_180523.png
tensorflow_keras_23_180523.png

全てうまく行った。

次に、「手書き文字の認識(2): CNNモデル」の keras_mnist_cnn.ipynb をやってみた。
tensorflow_keras_24_180523.png
tensorflow_keras_25_180523.png
tensorflow_keras_26_180523.png

こちらもうまく実行することができた。
  1. 2018年05月23日 05:04 |
  2. TensorFlow, Keras
  3. | トラックバック:0
  4. | コメント:0