Batch Normalization Tensorflow
- Batch Normalization
- learning rate를 너무 높게 잡을 경우 gradient가 explode/vanish 하거나, 비정상 local minima에 빠지는 문제발생. 이는 parameter들의 scale 때문, Batch Normalization을 사용시 backpropagation 할 때 parameter의 scale에 영향을 받지 않게 되며, 따라서, learning rate를 크게 잡을 수 있게 되고 이는 빠른 학습을 가능하게 함.
- Batch Normalization의 경우 자체적인 regularization 효과가 있음.
- Activation 함수 적용전에 사용.
- 미니 배치 단위로 평균과 분산을 구함.
- 평균과 분산을 구해 normalize 시킴.
- gamma를 더하고 beta를 곱해 scale 과 shift를 시킴.
*gamma와 beta는 trainable 한 파라미터로 backprob이 가능.
- training 시 mini batch 단위에서 구한 평균과 분산을 이용해 normalize.
- mini batch 단위에서 구한 average와 variance를 저장.
- moving average 방식
- Test 시 train할때 구했던 평균과 분산의 평균을 이용해 normalize.
- 추가적으로 연산 시 m/(m-1)를 곱해주게 되는데, 이것은 학습 전체 데이터에 대한 분산이 아니라 mini batch 분산을 통해 전체 분산을 추정하기 때문에 통계학적으로 보정을 위해 베셀의 보정값을 곱해준다고 한다..
- Convolution에서는 activation function 전에 batch norm을 사용한다. 기본적인 형태는 Wx+b 형태로 들어가게 되는데 batch norm에서의 beta 역할이 Wx+b 에서의 b의 역할을 대체하기 때문에 이를 없애준다.
- tf.layers.batch_normalization(training=) 의 parameter 로 training이 들어감. -> train 시 True, test와 validation 시 False로 설정해 주어야함. 하지만 이렇게 셋팅 후 모델을 돌려보면 test와 vali 에서 값이 이상하게 나오게 된다. 이는 batch norm의 train 연산시 계산되는 평균과 분산을 업데이트 시켜줘야 하는데 이를 자동으로 해주지 못함. 즉 수동으로 업데이트 시켜야함!!
import mnist_function
import tensorflow as tf
import pandas as pd
import numpy as np
# set data count
count = [400, 4000, 4000, 4000, 4000, 4000, 4000, 4000, 4000, 4000]
# load data set
train_x, train_y, test_x, test_y , vali_x, vali_y = mnist_function.data_set(count)
print('train test vali')
print(len(train_x),len(test_x),len(vali_x))
#train data shuffle
train_x, train_y = mnist_function.data_shuffle(train_x, train_y)
# set parameters
batch_size = 32
learning_rate = 0.001
training_epochs = 5
# Network Model.
tf.set_random_seed(777)
#keep_prob = tf.placeholder(tf.float32)
batch_prob = tf.placeholder(tf.bool)
X = tf.placeholder(tf.float32, [None, 784])
X_img = tf.reshape(X, [-1, 28, 28, 1])
Y = tf.placeholder(tf.float32, [None, 10])
W1 = tf.Variable(tf.random_normal([3, 3, 1, 32], stddev=0.01))
L1 = tf.nn.conv2d(X_img, W1, strides=[1, 1, 1, 1], padding='SAME')
L1 = tf.layers.batch_normalization(L1, center=True, scale=True, training=batch_prob)
L1 = tf.nn.relu(L1)
L1 = tf.nn.max_pool(L1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
W2 = tf.Variable(tf.random_normal([3, 3, 32, 64], stddev=0.01))
L2 = tf.nn.conv2d(L1, W2, strides=[1, 1, 1, 1], padding='SAME')
L2 = tf.layers.batch_normalization(L2, center=True, scale=True, training=batch_prob)
L2 = tf.nn.relu(L2)
L2 = tf.nn.max_pool(L2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
W3 = tf.Variable(tf.random_normal([3, 3, 64, 128], stddev=0.01))
L3 = tf.nn.conv2d(L2, W3, strides=[1, 1, 1, 1], padding='SAME')
L3 = tf.layers.batch_normalization(L3, center=True, scale=True, training=batch_prob)
L3 = tf.nn.relu(L3)
L3 = tf.nn.max_pool(L3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
L3_flat = tf.reshape(L3, [-1, 128 * 4 * 4])
W4 = tf.get_variable("W4", shape=[128 * 4 * 4, 100], initializer=tf.contrib.layers.xavier_initializer())
b4 = tf.Variable(tf.random_normal([100]))L4 = tf.layers.batch_normalization(L3_flat , center=True, scale=True, training=batch_prob)
L4 = tf.nn.relu(tf.matmul(L4 , W4) + b4)
W5 = tf.get_variable("W15", shape=[100, 10], initializer=tf.contrib.layers.xavier_initializer())
b5 = tf.Variable(tf.random_normal([10]))
L5 = tf.layers.batch_normalization(L4, center=True, scale=True, training=batch_prob)
logits = tf.matmul(L5, W5) + b5
y_pred = tf.nn.softmax(logits)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=Y, logits=logits))
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print('Learning start')
train_total_loss = []
train_total_acc = []
validation_total_acc = []
validation_total_loss = []
for epoch in range(training_epochs):
avg_loss = 0
avg_acc = 0
feature_train_li = []
train_y_li = []
total_batch = int(len(train_x) / batch_size)
start_index = 0
finish_index = batch_size
for i in range(total_batch):
batch = mnist_function.next_batch(start_index, finish_index, train_x, train_y)
start_index += batch_size
finish_index += batch_size
#, keep_prob: 0.7
feed_dict = {X: batch[0], Y: batch[1], batch_prob: True}
train_loss, _, feature_train, y_label_train, train_acc = sess.run([cost, optimizer, L4, Y, accuracy], feed_dict=feed_dict)
avg_loss += train_loss / total_batch
avg_acc += train_acc / total_batch
if(epoch + 1 == training_epochs):
feature_train_li.append(feature_train)
train_y_li.append(y_label_train)
# vali_x, vali_y
vali_loss, vali_acc = sess.run([cost, accuracy], feed_dict={X: vali_x, Y: vali_y, batch_prob: False})
Batch Normalization 적용. 미적용.
Tensorflow API - https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization
Paper - https://arxiv.org/pdf/1502.03167.pdf