CIFAR-10はkerasでダウンロードできるデータセットで、飛行機や車、鳥、猫などの10カテゴリーの写真6万枚が入っている。約8000万枚の画像が「80 Million Tiny images」というWEBサイトで公開されているが、そこから6万枚の画像を抽出し、ラベル付けしたデータセットがCIFAR-10。画像はフルカラーで32✖️32ピクセルと小さな画像になっている。
実行結果
今回のプログラムではCIFAR-10のデータをMLP(多層パーセプトロン)のアルゴリズムで、分類していく。
この実行結果ではあまり正解率は良くないがどんな画像データでも、一次元の配列データに変換すれば、MLPのモデルを利用してディープラーニングが実践できることを表している。
データの読み込みと各データの処理
import matplotlib.pyplot as plt
import keras
from keras.datasets import cifar10
from keras.models import Sequential
from keras.layers import Dense, Dropout
num_classes = 10
im_rows = 32
im_cols = 32
im_size = im_rows * im_cols * 3
# データを読み込む --- (*1)
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
# データを一次元配列に変換 --- (*2)
X_train = X_train.reshape(-1, im_size).astype('float32') / 255
X_test = X_test.reshape(-1, im_size).astype('float32') / 255
# ラベルデータをOne-Hot形式に変換
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
CIFAR-10をダウンロードするためcifar10.load_data()で各変数に格納している。そして、画像データをreshape()で一次元に変換。ラベルデータをone-hotベクトルに変換。
im_size = im_rows * im_cols *3の部分は一つの画像を一次元にしたとき、区切られる要素数の値をim_sizeに格納している。今回の画像は一枚あたり32✖️32のフルカラーの画像なので、配列で表現すると、フルカラーは3つの要素で表現されるので3つの要素を持った一次元の配列を32個(32列)持った二次元の配列が32行あるということになる。
[0,0,0]・・・・・・・・・・・3つの要素を持った配列。
[[0,0,0],[0,0,0]....[0,0,0]]・・・3つの要素の配列を32個持った二次元配列
1 2 32(個)
[[[0,0,0],[0,0,0]....[0,0,0]],・・・上記の配列が32行ある三次元配列
[[0,0,0],[0,0,0]....[0,0,0]],
........
[[0,0,0],[0,0,0]....[0,0,0]]]
すぐ上の三次元配列によって一つの画像は表現されているので、ここから一次元にした場合の要素数を計算すると
32 ✖️ 32 ✖️ 3 で求めることができる。
モデル定義
model = Sequential()
model.add(Dense(512, activation='relu', input_shape=(im_size,)))
model.add(Dense(num_classes, activation='softmax'))
model.compile(
loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])
学習
hist = model.fit(X_train, y_train,
batch_size=32, epochs=50,
verbose=1,
validation_data=(X_test, y_test))
モデル評価
score = model.evaluate(X_test, y_test, verbose=1)
print('正解率=', score[1], 'loss=', score[0])
学習の様子をグラフへ描画
plt.plot(hist.history['val_accuracy'])
plt.title('Accuracy')
plt.legend(['train', 'test'], loc='upper left')
plt.show()
plt.title('Loss')
plt.legend(['train', 'test'], loc='upper left')
plt.show()