TensorFlow: Is there a way to convert a frozen graph into a checkpoint model?


Problem Description

Converting a checkpoint model into a frozen graph is possible (.ckpt file to .pb file). However, is there a reverse method of converting a pb file into a checkpoint file once again?

I'd imagine it requires a conversion of the constants back into a variable - is there a way to identify the correct constants as variables and restore them back into a checkpoint model?

Currently there is support for conversion of variables to constants here: https://www.tensorflow.org/api_docs/python/tf/graph_util/convert_variables_to_constants

But not the other way around.
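For context, the forward direction looks roughly like this (a minimal sketch; 'model.ckpt' and the output node name 'output' are placeholders for your own model):

import tensorflow as tf

# Restore a checkpoint, fold its variables into constants, and write a frozen .pb.
with tf.Session(graph=tf.Graph()) as sess:
    saver = tf.train.import_meta_graph('model.ckpt.meta')
    saver.restore(sess, 'model.ckpt')
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), output_node_names=['output'])
    with tf.gfile.GFile('frozen_graph.pb', 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())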

A similar question has been raised here: Tensorflow: Convert constant tensor from pre-trained Vgg model to variable

But the solution relies on using a ckpt model to restore weight variables. Is there a way to restore weight variables from PB files instead of a checkpoint file? This could be useful for weight pruning.

Recommended Answer

There is a method for converting constants back to trainable variables in TensorFlow, via the Graph Editor. However, you will need to specify the nodes to convert, as I'm not sure if there is a way to automatically detect this in a robust manner.

Here are the steps:

Step 1: Load the .pb file into a graph object

import tensorflow as tf

# Load protobuf as graph, given filepath
def load_pb(path_to_pb):
    with tf.gfile.GFile(path_to_pb, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph

tf_graph = load_pb('frozen_graph.pb')

Step 2: Find constants that need conversion

Here are 2 ways to list the names of nodes in your graph:

  • Print them with this one-liner: print([n.name for n in tf_graph.as_graph_def().node])
  • Load the graph in Netron and browse the node names visually (more on this below).

The nodes you'll want to convert are likely named something along the lines of "Const". To be sure, it is a good idea to load your graph in Netron to see which tensors are storing the trainable weights. Oftentimes, it is safe to assume that all const nodes were once variables.
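If you want a quick programmatic shortlist, one option (just a sketch, building on the tf_graph loaded in Step 1) is to filter the graph's operations by type:

# Every op of type 'Const' is a candidate; confirm in Netron which ones actually hold trainable weights.
const_candidates = [op.name for op in tf_graph.get_operations() if op.type == 'Const']
print(const_candidates)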

Once you have identified these nodes, store their names in a list:

to_convert = [...] # names of tensors to convert

Step 3: Convert constants to variables

Run this code to convert your specified constants. It essentially creates corresponding variables for each constant and uses GraphEditor to unhook the constants from the graph, and hook the variables on.

import numpy as np
import tensorflow as tf
import tensorflow.contrib.graph_editor as ge

const_var_name_pairs = []
with tf_graph.as_default() as g:

    for name in to_convert:
        tensor = g.get_tensor_by_name('{}:0'.format(name))
        with tf.Session() as sess:
            tensor_as_numpy_array = sess.run(tensor)
        var_shape = tensor.get_shape()
        # Give each variable a name that doesn't already exist in the graph
        var_name = '{}_turned_var'.format(name)
        # Create TensorFlow variable initialized by values of original const.
        var = tf.get_variable(name=var_name, dtype='float32', shape=var_shape,
                              initializer=tf.constant_initializer(tensor_as_numpy_array))
        # We want to keep track of our variable names for later.
        const_var_name_pairs.append((name, var_name))

    # At this point, we added a bunch of tf.Variables to the graph, but they're
    # not connected to anything.

    # The magic: we use TF Graph Editor to swap the Constant nodes' outputs with
    # the outputs of our newly created Variables.

    for const_name, var_name in const_var_name_pairs:
        const_op = g.get_operation_by_name(const_name)
        var_reader_op = g.get_operation_by_name(var_name + '/read')
        ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))

Step 4: Save the result as a .ckpt (note that this code is still indented under the with tf_graph.as_default() as g: block from Step 3)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        save_path = tf.train.Saver().save(sess, 'model.ckpt')
        print("Model saved in path: %s" % save_path)

And voilà! You should be done at this point :) I was able to get this working myself and verified that the model weights are preserved; the only difference is that the graph is now trainable. Please let me know if there are any issues.
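If you want to spot-check this yourself, here is one way to do it (a sketch that assumes the const_var_name_pairs list and the 'model.ckpt' path from the steps above): compare each original constant against the tensor stored in the checkpoint.

import numpy as np
import tensorflow as tf

# Read back the checkpoint and compare each converted variable with its original constant.
reader = tf.train.NewCheckpointReader('model.ckpt')
with tf_graph.as_default() as g, tf.Session() as sess:
    for const_name, var_name in const_var_name_pairs:
        original = sess.run(g.get_tensor_by_name('{}:0'.format(const_name)))
        restored = reader.get_tensor(var_name)
        assert np.allclose(original, restored), 'Mismatch for {}'.format(var_name)
print('All converted weights match the checkpoint.')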
