duplicate a tensorflow graph


Problem description

What is the best way of duplicating a TensorFlow graph and keeping it up to date?

Ideally I want to put the duplicated graph on another device (e.g. from GPU to CPU) and then update the copy from time to time.

Recommended answer

Short answer: you probably want checkpoint files (permalink).

Long answer:

Let's be clear about the setup here. I'll assume that you have two devices, A and B, and you are training on A and running inference on B. Periodically, you'd like to update the parameters on the device running inference with new parameters found during training on the other. The tutorial linked above is a good place to start. It shows you how tf.train.Saver objects work, and you shouldn't need anything more complicated here.

Here is an example:

import tensorflow as tf

def build_net(graph, device):
  with graph.as_default():
    with graph.device(device):
      # Input placeholders
      inputs = tf.placeholder(tf.float32, [None, 784])
      labels = tf.placeholder(tf.float32, [None, 10])
      # Initialization
      w0 = tf.get_variable('w0', shape=[784,256], initializer=tf.contrib.layers.xavier_initializer())
      w1 = tf.get_variable('w1', shape=[256,256], initializer=tf.contrib.layers.xavier_initializer())
      w2 = tf.get_variable('w2', shape=[256,10], initializer=tf.contrib.layers.xavier_initializer())
      b0 = tf.Variable(tf.zeros([256]))
      b1 = tf.Variable(tf.zeros([256]))
      b2 = tf.Variable(tf.zeros([10]))
      # Inference network
      h1  = tf.nn.relu(tf.matmul(inputs, w0)+b0)
      h2  = tf.nn.relu(tf.matmul(h1,w1)+b1)
      output = tf.nn.softmax(tf.matmul(h2,w2)+b2)
      # Training network
      cross_entropy = tf.reduce_mean(-tf.reduce_sum(labels * tf.log(output), reduction_indices=[1]))
      optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)    
      # Your checkpoint function
      saver = tf.train.Saver()
      return tf.initialize_all_variables(), inputs, labels, output, optimizer, saver
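
A design note the answer relies on implicitly: tf.train.Saver saves and restores variables by name, so the training and inference programs must build identical graphs, which is why both call build_net. The device argument is independent of the checkpoint format, so, per the original question, you could build the training copy on a GPU while the inference copy stays on the CPU. A minimal sketch, assuming a GPU is actually available:

graphA = tf.Graph()
# Hypothetical variation: pin the training net to the GPU. Checkpoints
# store variable values by name, not device placement, so the CPU-side
# inference graph can still restore them.
init, inputs, labels, _, training_net, saver = build_net(graphA, '/gpu:0')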

Code for the training program:

def programA_main():
  from tensorflow.examples.tutorials.mnist import input_data
  mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
  # Build training network on device A
  graphA = tf.Graph()
  init, inputs, labels, _, training_net, saver = build_net(graphA, '/cpu:0')
  with tf.Session(graph=graphA) as sess:
    sess.run(init)
    for step in xrange(1,10000):
      batch = mnist.train.next_batch(50)
      sess.run(training_net, feed_dict={inputs: batch[0], labels: batch[1]})
      if step%100==0:
        saver.save(sess, '/tmp/graph.checkpoint')
        print 'saved checkpoint'
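
As written, each call to saver.save overwrites the same file. If you would rather keep a history of checkpoints, tf.train.Saver.save accepts a global_step argument. A small sketch, reusing the path and interval from the example above:

# Hypothetical variation: number checkpoint files by training step.
# By default the Saver keeps only the most recent five (max_to_keep=5).
if step % 100 == 0:
  saver.save(sess, '/tmp/graph.checkpoint', global_step=step)
  # writes /tmp/graph.checkpoint-100, /tmp/graph.checkpoint-200, ...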

...and code for an inference program:

def programB_main():
  from tensorflow.examples.tutorials.mnist import input_data
  mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
  # Build inference network on device B
  graphB = tf.Graph()
  init, inputs, _, inference_net, _, saver = build_net(graphB, '/cpu:0')
  with tf.Session(graph=graphB) as sess:
    batch = mnist.test.next_batch(50)

    saver.restore(sess, '/tmp/graph.checkpoint')
    print 'loaded checkpoint'
    out = sess.run(inference_net, feed_dict={inputs: batch[0]})
    print out[0]

    import time; time.sleep(2)

    saver.restore(sess, '/tmp/graph.checkpoint')
    print 'loaded checkpoint'
    out = sess.run(inference_net, feed_dict={inputs: batch[0]})
    print out[1]

If you fire up the training program and then the inference program, you'll see the inference program produces two different outputs (from the same input batch). This is a result of it picking up the parameters that the training program has checkpointed.

Now, this program obviously isn't your end point. We don't do any real synchronization, and you'll have to decide what "periodic" means with respect to checkpointing. But this should give you an idea of how to sync parameters from one network to another.
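
For example, one minimal way to make program B "periodic" is to poll for the newest checkpoint and reload only when it changes. This is a sketch, not part of the original answer; tf.train.latest_checkpoint is the standard helper for finding the most recent save, and the polling interval is arbitrary:

import time

def watch_and_infer(sess, saver, inference_net, inputs, batch,
                    checkpoint_dir='/tmp'):
  # Poll the checkpoint directory; rerun inference when a new save appears.
  last_seen = None
  while True:
    ckpt = tf.train.latest_checkpoint(checkpoint_dir)  # newest path, or None
    if ckpt is not None and ckpt != last_seen:
      saver.restore(sess, ckpt)
      last_seen = ckpt
      print sess.run(inference_net, feed_dict={inputs: batch})[0]
    time.sleep(5)  # arbitrary polling interval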

One final warning: this does not mean that the two networks are necessarily deterministic. There are known non-deterministic elements in TensorFlow (e.g., this), so be wary if you need exactly the same answer. But this is the hard truth about running on multiple devices.

Good luck!

