TensorflowのBatchNormalizationは難しい

はじめに

久しぶりにTensorflowをいじっていて、BatchNormalizationの挙動を確認した際の備忘録です。
実際の挙動確認や理論的なお話は下記によくまとまっています。
qiita.com
qiita.com

本ブログでは、ダメな実装について紹介したいと思います。

ダメな例

色々調べていると、正しく動く例はたくさん出てくるのですが、
当然といえば当然なのですが、ダメな例は殆ど紹介されていません。
早速ですが、正しく動作しない例を紹介します。

Keras API + カスタムEstimator

import tensorflow as tf

#define original model
def model_fn(features, labels, mode):
  
  training = (mode == tf.estimator.ModeKeys.TRAIN)
  x = tf.keras.layers.BatchNormalization()(features,training=training)
  y = tf.keras.layers.Dense(1)(x)
  predictions = {
      "prob": y,
  }
  
  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  loss = tf.reduce_mean(tf.losses.mean_squared_error(labels,y))

  if mode == tf.estimator.ModeKeys.TRAIN:
      optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
     
      # set batch normalization parameters to train
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())

      return tf.estimator.EstimatorSpec(
          mode=tf.estimator.ModeKeys.TRAIN,
          loss=loss,
          train_op=train_op)
    
  return tf.estimator.EstimatorSpec(
      mode=mode, loss=loss)
  
# create estimator
estimator = tf.estimator.Estimator(
    model_fn=model_fn)

input_fn = lambda:(tf.constant([[0], [1], [2], [3]], dtype=tf.float32),tf.constant([[0], [-1], [-2], [-3]], dtype=tf.float32))

estimator.train(input_fn,steps=5000)
result = estimator.evaluate(input_fn,steps=4)

この例は、最新のKeras APIであるtf.keras.layers.BatchNormalizationを使っています。
そしてCustomEstimatorを作って学習しています。
実際に動かしてみると、evaluateのlossがtrainとはかけ離れた値になります。
どうやら、BatchNormalizationが学習されていないようです。
問題の部分は下記です。

      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())

この部分で、BatchNormalizationのパラメータを学習するように指定しているつもりなのですが、
実はできていないようです。
そして、正しく指定する方法がよくわかりませんでした。
ついつい、上記の様な実装をしてしまいがちですが、避けたほうが良いようです。

正しく動作する例

1. tf.layers.BatchNormalziation + カスタムEstimator

BatchNormalizationをkeras API ではなく、layersから持ってきます。
最初の実装の該当部分を以下に変えるだけでOKです。

 x = tf.layers.BatchNormalization()(features,training=training)

この場合は with tf.get_collection(tf.GraphKeys.UPDATE_OPS)によって
BatchNormalizationのパラメータが正しく指定されます。

2. tf.keras.layers.BatchNormaliztion + model_to_estimator

EstimatorのCustomを諦めれば簡単に実装できます。

import tensorflow as tf

inputs = tf.keras.Input(shape=(1,))
x = tf.keras.layers.BatchNormalization()(inputs)
outputs = tf.keras.layers.Dense(1)(x)

# create Keras Model
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam',
              loss='mse',
              metrics=['accuracy'])

# convert to estimator
estimator = tf.keras.estimator.model_to_estimator(model)

input_fn = lambda:(tf.constant([[0], [1], [2], [3]], dtype=tf.float32),tf.constant([[0], [-1], [-2], [-3]], dtype=tf.float32))

estimator.train(input_fn,steps=5000)
result = estimator.evaluate(input_fn,steps=4)
print(result)

おそらく公式の推奨はこれかと思われます。
記述量も短くてわかりやすいですね。

3. (おまけ) Estimatorを使わない

import tensorflow as tf
import numpy as np

inputs = tf.placeholder(shape=[None,1], dtype=tf.float32)
labels = tf.placeholder(shape=[None,1], dtype=tf.float32)
training = tf.placeholder(shape=[], dtype=tf.bool)

with tf.variable_scope('batch_normalization_test'):
  with tf.name_scope('for_train'):
    BN1 =  tf.layers.BatchNormalization(name="bn1")
    x = BN1(inputs,training =training)
    x = tf.layers.Dense(units=1,name="dense1")(x)
    BN2 =  tf.layers.BatchNormalization(name="bn2")
    x = BN2(x,training =training)
    output1 = tf.layers.Dense(units=1,name="dense2")(x)

with tf.variable_scope('batch_normalization_test', reuse=True):
  with tf.name_scope('for_test'):
    BN1 =  tf.layers.BatchNormalization(name="bn1")
    x = BN1(inputs,training =training)
    x = tf.layers.Dense(units=1,name="dense1")(x)
    BN2 =  tf.layers.BatchNormalization(name="bn2")
    x = BN2(x,training =training)
    output2 = tf.layers.Dense(units=1,name="dense2")(x)

loss = tf.reduce_mean(tf.losses.mean_squared_error(labels,output1))
optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
grads = optimizer.compute_gradients(loss)
with tf.control_dependencies(update_ops):
  train_op = optimizer.apply_gradients(grads)

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

train_x = np.array([[-3.],[-2.],[-1.],[0.]])
train_y = np.array([[0.], [1.], [2.], [3.]])

for _ in range(5000):
  sess.run((loss,train_op,output1,output2),{inputs:train_x,labels:train_y,training:True})
#  print(sess.run(BN1.weights))
#  print(sess.run(BN2.weights))

print(sess.run((loss,output1,output2),{inputs:train_x,labels:train_y,training:True}))
print(sess.run((output1),{inputs:train_x,training:False}))
print(sess.run((output2),{inputs:train_x,training:False}))

途中いろいろ余計な部分がありますが、上記のような古いスタイルで問題なく動きます。
Keras API と Estimatorを諦めた方が分かりやすくなるのは気のせいでしょうか。

おわりに

Keras APIのBatchNormalizationを使うときは注意が必要です。
ドキュメントではあまり触れられていませんし、本家Issueでも投げやりな対応です。
(TF2.0が来るから?)
少し試す分にはKeras APIは便利ですが、
研究用途でがっつり使いたい場合は避けた方がよいかもしれません。

Google Colaboratoryで確認した際のソースを下記に公開しています。
良ければお試しください。
https://github.com/yoneken1/colab_tensorflow/blob/master/tensorflow_batch_normalization.ipynb