Tensorflow weight initialization

Problem description

Regarding the MNIST tutorial on the TensorFlow website, I ran an experiment (gist: https://gist.github.com/prinsherbert/fa151010569e4dc18354296ba4f14efc) to see what the effect of different weight initializations would be on learning. I noticed that, contrary to what I read in the popular [Xavier, Glorot 2010] paper, learning is just fine regardless of weight initialization.

The different curves represent different values of w used to initialize the weights of the convolutional and fully connected layers. Note that all values of w work fine, even though 0.3 and 1.0 end up at lower performance and some values train faster - in particular, 0.03 and 0.1 are fastest. Nevertheless, the plot shows a rather large range of w that works, suggesting robustness with respect to weight initialization.

def weight_variable(shape, w=0.1):
  initial = tf.truncated_normal(shape, stddev=w)
  return tf.Variable(initial)

def bias_variable(shape, w=0.1):
  initial = tf.constant(w, shape=shape)
  return tf.Variable(initial)

Question: Why does this network not suffer from the vanishing or exploding gradient problem?

I would suggest you read the gist for implementation details, but here is the code for reference. It took approximately an hour on my Nvidia 960m, although I imagine it could also run on a CPU in a reasonable time.

import time
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
from tensorflow.python.client import device_lib

import numpy
import matplotlib.pyplot as pyplot

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# Weight initialization

def weight_variable(shape, w=0.1):
  initial = tf.truncated_normal(shape, stddev=w)
  return tf.Variable(initial)

def bias_variable(shape, w=0.1):
  initial = tf.constant(w, shape=shape)
  return tf.Variable(initial)


# Network architecture

def conv2d(x, W):
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                    strides=[1, 2, 2, 1], padding='SAME')

def build_network_for_weight_initialization(w):
    """ Builds a CNN for the MNIST-problem:
     - 32 5x5 kernels convolutional layer with bias and ReLU activations
     - 2x2 maxpooling
     - 64 5x5 kernels convolutional layer with bias and ReLU activations
     - 2x2 maxpooling
     - Fully connected layer with 1024 nodes + bias and ReLU activations
     - dropout
     - Fully connected softmax layer for classification (of 10 classes)

     Returns the x and y_ placeholders for the train data, the output
     of the network and the dropout placeholder as a tuple of 4 elements.
    """
    x = tf.placeholder(tf.float32, shape=[None, 784])
    y_ = tf.placeholder(tf.float32, shape=[None, 10])

    x_image = tf.reshape(x, [-1,28,28,1])
    W_conv1 = weight_variable([5, 5, 1, 32], w)
    b_conv1 = bias_variable([32], w)

    h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
    h_pool1 = max_pool_2x2(h_conv1)
    W_conv2 = weight_variable([5, 5, 32, 64], w)
    b_conv2 = bias_variable([64], w)

    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
    h_pool2 = max_pool_2x2(h_conv2)

    W_fc1 = weight_variable([7 * 7 * 64, 1024], w)
    b_fc1 = bias_variable([1024], w)

    h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

    keep_prob = tf.placeholder(tf.float32)
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

    W_fc2 = weight_variable([1024, 10], w)
    b_fc2 = bias_variable([10], w)

    y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2

    return (x, y_, y_conv, keep_prob)


# Experiment

def evaluate_for_weight_init(w):
    """ Returns an accuracy learning curve for a network trained on
    10000 batches of 50 samples. The learning curve has one item
    every 100 batches."""
    with tf.Session() as sess:
        x, y_, y_conv, keep_prob = build_network_for_weight_initialization(w)
        cross_entropy = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
        train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
        correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        sess.run(tf.global_variables_initializer())
        lr = []
        for _ in range(100):
            for i in range(100):
                batch = mnist.train.next_batch(50)
                train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
            assert mnist.test.images.shape[0] == 10000
            # This way the accuracy-evaluation fits in my 2GB laptop GPU.
            a = sum(
                accuracy.eval(feed_dict={
                    x: mnist.test.images[2000*i:2000*(i+1)],
                    y_: mnist.test.labels[2000*i:2000*(i+1)],
                    keep_prob: 1.0})
                for i in range(5)) / 5
            lr.append(a)
        return lr


ws = [0.0001, 0.0003, 0.001, 0.003, 0.01, 0.03, 0.1, 0.3, 1.0]
accuracies = [
    [evaluate_for_weight_init(w) for w in ws]
    for _ in range(3)
]


# Plotting results

pyplot.plot(numpy.array(accuracies).mean(0).T)
pyplot.ylim(0.9, 1)
pyplot.xlim(0,140)
pyplot.xlabel('batch (x 100)')
pyplot.ylabel('test accuracy')
pyplot.legend(ws)
pyplot.show()

Answer

Weight initialization strategies can be an important and often overlooked step in improving your model, and since this is now the top result on Google I thought it could warrant a more detailed answer.

In general, the product of each layer's activation function gradient, number of incoming/outgoing connections (fan_in/fan_out), and variance of the weights should be roughly equal to one. This way, as you backpropagate through the network, the variance of the gradients going into and out of each layer stays consistent, and you won't suffer from exploding or vanishing gradients. Even though ReLU is more resistant to exploding/vanishing gradients, you might still have problems.
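
To make this scaling rule concrete, here is a minimal sketch (plain NumPy; the helper names are invented for illustration) of the standard deviations that the He and Xavier/Glorot rules would suggest for the layers of the network above:

import numpy as np

def he_stddev(fan_in):
    # He et al. (2015): preserves activation variance through ReLU layers.
    return np.sqrt(2.0 / fan_in)

def xavier_stddev(fan_in, fan_out):
    # Glorot & Bengio (2010): balances forward and backward variance,
    # originally derived for tanh/sigmoid activations.
    return np.sqrt(2.0 / (fan_in + fan_out))

# (fan_in, fan_out) of the layers built above; for a conv layer the fan is
# kernel_height * kernel_width * channels.
layers = [
    ('conv1', 5 * 5 * 1,  5 * 5 * 32),
    ('conv2', 5 * 5 * 32, 5 * 5 * 64),
    ('fc1',   7 * 7 * 64, 1024),
    ('fc2',   1024,       10),
]

for name, fan_in, fan_out in layers:
    print('%s  He: %.3f  Xavier: %.3f'
          % (name, he_stddev(fan_in), xavier_stddev(fan_in, fan_out)))

These suggested values land roughly between 0.02 and 0.3, i.e. in the same band as the w values (0.03 and 0.1) that trained fastest in the experiment above, which is consistent with this fairly shallow network being forgiving about initialization.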

tf.truncated_normal, as used by the OP, performs a random initialization that encourages the weights to be updated "differently" (it breaks symmetry), but does not take the above scaling strategy into account. On smaller networks this might not be a problem, but if you want deeper networks, or faster training times, then you are best off trying a weight initialization strategy based on recent research.

For weights preceding a ReLU function you could use the default settings of:

tf.contrib.layers.variance_scaling_initializer

For tanh/sigmoid-activated layers, "xavier" might be more appropriate:

tf.contrib.layers.xavier_initializer

More details on both these functions and associated papers can be found at: https://www.tensorflow.org/versions/r0.12/api_docs/python/contrib.layers/initializers
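
As a rough sketch of how these initializers could be dropped into the weight_variable helper from the question (this is an adaptation, not the OP's original code; the activation argument is invented for illustration):

import tensorflow as tf

def weight_variable(shape, activation='relu'):
    # Variance-scaling ("He") init for ReLU layers, Xavier/Glorot otherwise.
    if activation == 'relu':
        initializer = tf.contrib.layers.variance_scaling_initializer()
    else:
        initializer = tf.contrib.layers.xavier_initializer()
    return tf.Variable(initializer(shape))

# Example: the first convolutional layer of the network above.
W_conv1 = weight_variable([5, 5, 1, 32])   # ReLU layer -> variance scaling
b_conv1 = tf.Variable(tf.zeros([32]))      # biases are commonly zero-initialized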

Beyond weight initialization strategies, further optimization could explore batch normalization: https://www.tensorflow.org/api_docs/python/tf/nn/batch_normalization
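
For completeness, here is a minimal sketch of how batch normalization could be inserted after a convolution in a graph like the one above. It uses the higher-level tf.layers wrappers rather than tf.nn.batch_normalization directly; the is_training placeholder and the exact layer arguments are assumptions of this sketch, not part of the original code:

import tensorflow as tf

x_image = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
is_training = tf.placeholder(tf.bool)  # switch between batch and moving statistics

# Convolution -> batch norm -> ReLU; batch norm makes the conv bias redundant.
conv = tf.layers.conv2d(x_image, filters=32, kernel_size=5,
                        padding='same', use_bias=False)
bn = tf.layers.batch_normalization(conv, training=is_training)
h_conv1 = tf.nn.relu(bn)

# Batch norm maintains moving averages that must be updated alongside the
# train op, so the optimizer would be wrapped in the usual TF 1.x pattern:
#   update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
#   with tf.control_dependencies(update_ops):
#       train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)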
