Retrain Frozen Graph in Tensorflow 2.x


Problem Description

I have managed this implementation of retraining a frozen graph in TensorFlow 1 according to this wonderful, detailed topic. Basically, the methodology is:

  1. Load the frozen model.
  2. Replace the constant frozen nodes with variable nodes.
  3. Redirect the newly created variable nodes to the corresponding outputs of the frozen nodes.

This works in TensorFlow 1.x, as can be verified by checking tf.compat.v1.trainable_variables. However, in TensorFlow 2.x it no longer works.
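For reference, the check mentioned above is simply listing the trainable variables collected in the graph (a minimal sketch; detection_graph is the graph built in the snippets that follow):

with detection_graph.as_default():
    print(tf.compat.v1.trainable_variables())  # non-empty in TF 1.x after the rewrite, but not in TF 2.x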

Here are the code snippets:

1/ Load the frozen model

import tensorflow as tf

frozen_path = '...'
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.compat.v1.GraphDef()
    with tf.compat.v1.io.gfile.GFile(frozen_path, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.graph_util.import_graph_def(od_graph_def, name='')

2/ Create a clone

with detection_graph.as_default():
    const_var_name_pairs = {}
    probable_variables = [op for op in detection_graph.get_operations() if op.type == "Const"]
    available_names = [op.name for op in detection_graph.get_operations()]
    for op in probable_variables:
        name = op.name
        if name+'/read' not in available_names:
            continue
        tensor = detection_graph.get_tensor_by_name('{}:0'.format(name))
        with tf.compat.v1.Session() as s:
            tensor_as_numpy_array = s.run(tensor)
        var_shape = tensor.get_shape()
        # Give each variable a name that doesn't already exist in the graph
        var_name = '{}_turned_var'.format(name)
        var = tf.Variable(name=var_name,
                          dtype=op.outputs[0].dtype,
                          initial_value=tensor_as_numpy_array,
                          trainable=True,
                          shape=var_shape)
        const_var_name_pairs[name] = var_name
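A quick sanity check (my own sketch, not part of the original code): every constant that had a '/read' consumer should now have a matching variable registered in the graph's collections.

with detection_graph.as_default():
    turned_vars = [v for v in tf.compat.v1.global_variables()
                   if v.op.name.endswith('_turned_var')]
assert len(turned_vars) == len(const_var_name_pairs)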

3/ Replace the frozen nodes using the Graph Editor

import graph_def_editor as ge
ge_graph = ge.Graph(detection_graph.as_graph_def())
name_to_op = dict([(n.name, n) for n in ge_graph.nodes])
for const_name, var_name in const_var_name_pairs.items():
    const_op = name_to_op[const_name+'/read']
    var_reader_op = name_to_op[var_name + '/Read/ReadVariableOp']
    ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))
detection_training_graph = ge_graph.to_tf_graph()
with detection_training_graph.as_default():
    writer = tf.compat.v1.summary.FileWriter('remap', detection_training_graph)
    writer.close()

Recommended Answer

The problem was in my Graph Editor step: I imported the tf.graph_def instead of the original tf.graph, which is what actually holds the Variables.

It can be solved quickly by fixing step 3.

Sol1: Use the Graph Editor

# Pass the tf.Graph itself (not its GraphDef), so the Variables are carried over
ge_graph = ge.Graph(detection_graph)
for const_name, var_name in const_var_name_pairs.items():
    const_op = ge_graph._node_name_to_node[const_name+'/read']
    var_reader_op = ge_graph._node_name_to_node[var_name+'/Read/ReadVariableOp']
    ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))

However, this requires disabling eager execution (a minimal sketch of that follows). To keep working with eager execution enabled instead, you should attach the MetaGraphDef to the Graph Editor, as shown in the second snippet below.
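If turning eager execution off is acceptable, that is a single call made once at startup, before any of the graph work above (a minimal sketch):

import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # must be called before building or editing any graphs

With eager execution left on, the MetaGraphDef route looks like this: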

with detection_graph.as_default():
    meta_saver = tf.compat.v1.train.Saver()
    meta = meta_saver.export_meta_graph()
ge_graph = ge.Graph(detection_graph,collections=ge.graph._extract_collection_defs(meta))

However, this is the trickiest part of making the model trainable in tf2.x. Instead of using the Graph Editor to export the graph directly, we should export it ourselves. The reason is that the Graph Editor turns the Variables' data type into resource. Therefore, we should export the graph as a GraphDef and import the VariableDef into the graph:

from tensorflow.core.framework import variable_pb2

test_graph = tf.Graph()
with test_graph.as_default():
    tf.import_graph_def(ge_graph.to_graph_def(), name="")
    for var_name in ge_graph.variable_names:
        var = ge_graph.get_variable_by_name(var_name)
        ret = variable_pb2.VariableDef()
        ret.variable_name = var._variable_name
        ret.initial_value_name = var._initial_value_name
        ret.initializer_name = var._initializer_name
        ret.snapshot_name = var._snapshot_name
        ret.trainable = var._trainable
        ret.is_resource = True
        tf_var = tf.Variable(variable_def=ret, dtype=tf.float32)
        test_graph.add_to_collections(var.collection_names, tf_var)
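Once the variables are attached to test_graph, they can be initialized and checkpointed with the classic Saver (my own sketch; the checkpoint path is a hypothetical placeholder):

with test_graph.as_default():
    saver = tf.compat.v1.train.Saver(tf.compat.v1.trainable_variables())
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        saver.save(sess, './retrainable_model.ckpt')  # hypothetical path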

Sol2: Manually remap via GraphDef

with detection_graph.as_default() as graph:
    # remap_input_node (the author's helper, not shown in the post) rewires the consumers of
    # each frozen constant to the matching variable from const_var_name_pairs
    training_graph_def = remap_input_node(detection_graph.as_graph_def(), const_var_name_pairs)
    current_var = tf.compat.v1.trainable_variables()
    assert len(current_var) > 0, "no training variables"


detection_training_graph = tf.Graph()
with detection_training_graph.as_default():
    tf.graph_util.import_graph_def(training_graph_def, name='')
    for var in current_var:
        ret = variable_pb2.VariableDef()
        ret.variable_name = var.name
        ret.initial_value_name = var.name[:-2] + '/Initializer/initial_value:0'
        ret.initializer_name = var.name[:-2] + '/Assign'
        ret.snapshot_name = var.name[:-2] + '/Read/ReadVariableOp:0'
        ret.trainable = True
        ret.is_resource = True
        tf_var = tf.Variable(variable_def=ret, dtype=tf.float32)
        detection_training_graph.add_to_collections({'trainable_variables', 'variables'}, tf_var)
    current_var = tf.compat.v1.trainable_variables()
    assert len(current_var) > 0, "no training variables"
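With the variables registered, the remapped graph can be fine-tuned with a classic TF1-style session loop. The tensor names below ('input:0', 'logits:0') and the labels placeholder are hypothetical stand-ins for whatever the frozen model actually exposes; a minimal sketch:

with detection_training_graph.as_default():
    # Hypothetical names; substitute the real input/output tensor names of your frozen model.
    inputs = detection_training_graph.get_tensor_by_name('input:0')
    logits = detection_training_graph.get_tensor_by_name('logits:0')
    labels = tf.compat.v1.placeholder(tf.float32, shape=logits.shape, name='labels')

    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits))
    train_op = tf.compat.v1.train.AdamOptimizer(1e-4).minimize(
        loss, var_list=tf.compat.v1.trainable_variables())

    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        # sess.run(train_op, feed_dict={inputs: batch_x, labels: batch_y})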
