Kerasで転移学習してSNSの画像を分類した話

授業でInstagram, Twitter, Facebookの画像を分類するアプリケーションを作成しました。

データ収集

instagram

まとめサイトでまとめられている、一般人の有名インスタグラマーのidを集め、その人たちの画像を約4万枚集めました。

twitter

1000RT以上、1000Fav以上の画像付きツイートの画像を約1万5千枚集めました。

facebook

これはチームメンバーがやってくれたのですが、友達をひたすらたどって画像を集めたようです。 結果として約1万5千枚集まりました。

モデル作成

フルで学習させると精度が中々出なかったので、ImageNetで学習済みのVGG16を使って転移学習をしました。転移学習なので、画像は結果的に3000枚ずつ、合計9000枚(training: 8100枚, validation: 900枚)を使いました。 転移学習はkerasを使うとめちゃくちゃ簡単にできます。

input_tensor = Input(shape=(100, 100, 3))
vgg16 = VGG16(include_top=False, weights='imagenet', input_tensor=input_tensor)

# FC層を構築
top_model = Sequential()
top_model.add(Flatten(input_shape=vgg16.output_shape[1:]))
top_model.add(Dense(256, activation='relu'))
top_model.add(Dropout(0.2))
top_model.add(Dense(3, activation='softmax'))

# VGG16とFCを接続
model = Model(input=vgg16.input, output=top_model(vgg16.output))

# 最後のconv層の直前までの層をfreeze
for layer in model.layers[:15]:
    layer.trainable = False

# Fine-tuningのときはSGDの方がよい
model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.SGD(lr=1e-4, momentum=0.9),
              metrics=['accuracy'])

軽く解説すると、まず、今回は入力画像を100×100にしました。そしてカラーなので、Inputのshapeは(100, 100, 3)です。 続いてVGG16のモデルを読み込んでいますが、今回は転移学習するので、全結合層の部分は入りません。したがってinclude_top=False, そして、imagenetで学習済みの重みを使うのでweights=’imagenet’です。 そして、全結合層を構築して、先ほどのVGG16と接続しています。今回は、Instagram, Twitter, Facebookの3値分類を行うので最終層の出力は3になっています。 全結合層以外の部分は学習されてほしくないのでfreezeしています。(参照: https://keras.io/ja/getting-started/faq/#freeze

結果

f:id:wakanapo:20180123132205p:plain

epochによるaccuracyの変化

validationのデートセットで約68%ほどの精度を出すことができました。

http://35.190.172.104/で実際に試すことができます。

webアプリケーション

flaskを使ってサーバーを立てました。 これは特に問題なく動いたので、herokuでデプロイしようとしましたが、”30秒ルール” (参照: https://devcenter.heroku.com/articles/error-codes#h12-request-timeout )に引っかかってしまうようでうまく動きませんでした。 次にGoogle App Engineを試してみました。無料枠のstandard版だとpython2.7しかサポートされていません。キレそう。幸い完全に互換性のあるコードだったのでPython2.7でも問題なく動きました。ですが、GAEでは純粋なPythonしかサポートされていないそうです。numpyはbuilt-inとしてサポートされているようですが、OpenCVやKeras(Tensorflow)が使用できないので諦めました。 結局、Google Compute Engineを使ってサーバーを立てました。無料枠だけではメモリが足りなかったのですこしメモリを足しました。

写真が投稿されてから、学習した重みを読み込み推論を行うようになっているのでレスポンスが結構遅いです。続けて2枚目、3枚目の写真を投げた場合は、すでに重みが読み込まれているので少し早くなります。要改善と言った感じです。

おわりに

Deep Learning Day 2018でポスター発表をしました。

www.slideshare.net

https://github.com/wakanapo/SNS_clissfication/blob/master/vgg16.ipynb

参考

http://aidiary.hatenablog.com/entry/20170131/1485864665