Retrain Frozen Graph in Tensorflow 2.x
Problem description
I have managed to retrain a frozen graph in TensorFlow 1 by following this wonderful, detailed topic. Basically, the methodology is:
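- Load the frozen model.
- Replace the constant (frozen) nodes with variable nodes.
- Redirect the newly created variable nodes to the corresponding outputs of the frozen nodes, so that every consumer of a frozen node reads from its variable instead.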
This works in TensorFlow 1.x, as confirmed by checking tf.compat.v1.trainable_variables. However, in TensorFlow 2.x it no longer works.
Here is the code snippet:
1/ Load the frozen model
import tensorflow as tf

frozen_path = '...'
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.compat.v1.GraphDef()
    with tf.compat.v1.io.gfile.GFile(frozen_path, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        # name='' keeps the original node names (no 'import/' prefix)
        tf.graph_util.import_graph_def(od_graph_def, name='')
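The name='' argument matters because every later lookup ('{}:0'.format(name), name + '/read') assumes the node names are unchanged. By default import_graph_def prepends the scope "import/" to every imported node. The toy helpers below are hypothetical and contain no TensorFlow calls; they only mimic that naming rule and the '<op name>:<output index>' convention used for tensor names.

```python
def imported_node_name(node_name: str, scope: str = "import") -> str:
    """Name a node gets after import: '<scope>/<name>', or unchanged if scope is ''.

    Toy model of the `name` argument of import_graph_def (default prefix 'import').
    """
    return f"{scope}/{node_name}" if scope else node_name


def tensor_name(op_name: str, output_index: int = 0) -> str:
    """Tensors are addressed as '<op name>:<output index>'."""
    return f"{op_name}:{output_index}"


if __name__ == "__main__":
    # With the default scope, a lookup of 'conv1/weights/read' would fail.
    print(imported_node_name("conv1/weights"))             # import/conv1/weights
    print(imported_node_name("conv1/weights", scope=""))   # conv1/weights
    print(tensor_name("conv1/weights"))                    # conv1/weights:0
```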
2/ Create a variable clone of each frozen constant
with detection_graph.as_default():
    const_var_name_pairs = {}
    # Every Const op is a candidate frozen variable
    probable_variables = [op for op in detection_graph.get_operations() if op.type == "Const"]
    available_names = {op.name for op in detection_graph.get_operations()}
    for op in probable_variables:
        name = op.name
        # A former variable still has its '<name>/read' node; skip plain constants
        if name + '/read' not in available_names:
            continue
        tensor = detection_graph.get_tensor_by_name('{}:0'.format(name))
        # Evaluate the constant to get its value as a NumPy array
        with tf.compat.v1.Session() as s:
            tensor_as_numpy_array = s.run(tensor)
        var_shape = tensor.get_shape()
        # Give each variable a name that doesn't already exist in the graph
        var_name = '{}_turned_var'.format(name)
        var = tf.Variable(name=var_name, dtype=op.outputs[0].dtype,
                          initial_value=tensor_as_numpy_array, trainable=True, shape=var_shape)
        const_var_name_pairs[name] = var_name
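The filter in step 2 can be tried out without TensorFlow. The sketch below is a hypothetical pure-Python version: it takes (name, type) pairs instead of real tf.Operation objects (an assumption for illustration), keeps only Const nodes that still have a '<name>/read' node (the Identity a TF1 variable had before freezing), and produces the same '{name}_turned_var' naming as the snippet above.

```python
from typing import Dict, Iterable, Tuple


def find_const_var_name_pairs(ops: Iterable[Tuple[str, str]]) -> Dict[str, str]:
    """Map each frozen-variable Const name to a fresh variable name.

    `ops` is a hypothetical stand-in for
    [(op.name, op.type) for op in detection_graph.get_operations()].
    """
    ops = list(ops)
    names = {name for name, _ in ops}  # set: constant-time membership checks
    pairs = {}
    for name, op_type in ops:
        # Only Const nodes with a '<name>/read' sibling were variables before
        # freezing; other constants (e.g. shape tensors) are skipped.
        if op_type == "Const" and name + "/read" in names:
            pairs[name] = name + "_turned_var"
    return pairs


if __name__ == "__main__":
    ops = [
        ("conv1/weights", "Const"),
        ("conv1/weights/read", "Identity"),
        ("flatten/shape", "Const"),
        ("conv1/Conv2D", "Conv2D"),
    ]
    print(find_const_var_name_pairs(ops))
    # {'conv1/weights': 'conv1/weights_turned_var'}
```

Using a set for the name lookup avoids a linear scan of all op names for every constant, which adds up on large detection graphs.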
3/ Replace the frozen nodes with Graph Editor
import graph_def_editor as ge

ge_graph = ge.Graph(detection_graph.as_graph_def())
name_to_op = {n.name: n for n in ge_graph.nodes}
for const_name, var_name in const_var_name_pairs.items():
    const_op = name_to_op[const_name + '/read']
    var_reader_op = name_to_op[var_name + '/Read/ReadVariableOp']
    # Consumers of the frozen '/read' node now read from the variable instead
    ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))

detection_training_graph = ge_graph.to_tf_graph()
with detection_training_graph.as_default():
    writer = tf.compat.v1.summary.FileWriter('remap', detection_training_graph)
    writer.close()  # close() must be called; a bare `writer.close` does nothing
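To see what the rewiring in step 3 does, here is a toy, TensorFlow-free model of swapping outputs on a graph stored as "node name -> list of input tensor names". It is an assumption-based sketch modelled on the documented behaviour of the similarly named function in tf.contrib.graph_editor (consumers of the first subgraph's outputs are connected to the second's, and vice versa); it is not graph_def_editor's implementation.

```python
from typing import Dict, List

# Toy graph: node name -> list of input tensor names ("<op>:<index>").
Graph = Dict[str, List[str]]


def swap_outputs(graph: Graph, a: str, b: str) -> Graph:
    """Return a copy where consumers of a's outputs read b's, and vice versa."""
    swapped: Graph = {}
    for node, inputs in graph.items():
        rewired = []
        for tensor in inputs:
            op, idx = tensor.rsplit(":", 1)
            if op == a:
                tensor = f"{b}:{idx}"
            elif op == b:
                tensor = f"{a}:{idx}"
            rewired.append(tensor)
        swapped[node] = rewired
    return swapped


if __name__ == "__main__":
    g = {
        "input": [],
        "w": [],                                   # frozen Const
        "w/read": ["w:0"],
        "w_turned_var": [],                        # new variable
        "w_turned_var/Read/ReadVariableOp": ["w_turned_var:0"],
        "conv": ["input:0", "w/read:0"],
    }
    new = swap_outputs(g, "w/read", "w_turned_var/Read/ReadVariableOp")
    print(new["conv"])  # ['input:0', 'w_turned_var/Read/ReadVariableOp:0']
```

After the swap the frozen 'w/read' node still exists but has no consumers, so the constant no longer influences the output and the variable receives the gradients.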
Answer
The problem was that my Graph Editor imported the tf.graph_def instead of the original tf.graph, which has the Variables.
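Why the GraphDef loses the variables: a GraphDef proto serializes only the nodes, while tf.compat.v1.trainable_variables() reads the graph's "trainable_variables" collection, and collections live in the MetaGraphDef, not in the GraphDef. The toy classes below are hypothetical and use no TensorFlow; they only model that split.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyGraph:
    """Toy stand-in for a tf.Graph: a node list plus named collections."""
    nodes: List[str]
    collections: Dict[str, List[str]] = field(default_factory=dict)

    def to_graph_def(self) -> List[str]:
        # Like a GraphDef proto: only the nodes are serialized, not collections.
        return list(self.nodes)

    def trainable_variables(self) -> List[str]:
        # Like tf.compat.v1.trainable_variables(): reads a collection.
        return list(self.collections.get("trainable_variables", []))


def import_graph_def(graph_def: List[str]) -> ToyGraph:
    # Rebuilding from a GraphDef yields the same nodes but empty collections.
    return ToyGraph(nodes=list(graph_def))


if __name__ == "__main__":
    original = ToyGraph(
        nodes=["w_turned_var", "w_turned_var/Read/ReadVariableOp"],
        collections={"trainable_variables": ["w_turned_var:0"]},
    )
    rebuilt = import_graph_def(original.to_graph_def())
    print(original.trainable_variables())  # ['w_turned_var:0']
    print(rebuilt.trainable_variables())   # []
```

This is also why both solutions in this answer put the collections back, either by passing them from a MetaGraphDef to Graph Editor or by calling add_to_collections after importing the GraphDef.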
This was quickly solved by fixing step 3.
Sol1: Using Graph Editor
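# Build the Graph Editor graph from the tf.Graph itself, not from its GraphDef
ge_graph = ge.Graph(detection_graph)
for const_name, var_name in const_var_name_pairs.items():
    # _node_name_to_node is a private attribute of graph_def_editor's Graph
    const_op = ge_graph._node_name_to_node[const_name + '/read']
    var_reader_op = ge_graph._node_name_to_node[var_name + '/Read/ReadVariableOp']
    ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))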
However, this requires disabling eager execution. To make it work with eager execution enabled, you should attach the MetaGraphDef to Graph Editor as below:
with detection_graph.as_default():
    meta_saver = tf.compat.v1.train.Saver()
    meta = meta_saver.export_meta_graph()
ge_graph = ge.Graph(detection_graph, collections=ge.graph._extract_collection_defs(meta))
However, this is the trickiest part of making the model trainable in TF 2.x. Instead of using Graph Editor to export the graph directly, we should export it ourselves. The reason is that Graph Editor sets the Variables' data type to resource. Therefore, we should export the graph as a GraphDef and import the variable defs into the graph:
from tensorflow.core.framework import variable_pb2

test_graph = tf.Graph()
with test_graph.as_default():
    tf.import_graph_def(ge_graph.to_graph_def(), name="")
    for var_name in ge_graph.variable_names:
        var = ge_graph.get_variable_by_name(var_name)
        # Rebuild a VariableDef proto from Graph Editor's variable metadata
        ret = variable_pb2.VariableDef()
        ret.variable_name = var._variable_name
        ret.initial_value_name = var._initial_value_name
        ret.initializer_name = var._initializer_name
        ret.snapshot_name = var._snapshot_name
        ret.trainable = var._trainable
        ret.is_resource = True
        # Assumes every variable is float32
        tf_var = tf.Variable(variable_def=ret, dtype=tf.float32)
        # Register the variable so tf.compat.v1.trainable_variables() finds it
        test_graph.add_to_collections(var.collection_names, tf_var)
Sol2: Manually map via GraphDef
from tensorflow.core.framework import variable_pb2

with detection_graph.as_default() as graph:
    # remap_input_node is not defined in this answer; it presumably rewrites
    # the inputs that pointed at each constant to its replacement variable
    training_graph_def = remap_input_node(detection_graph.as_graph_def(), const_var_name_pairs)
    current_var = tf.compat.v1.trainable_variables()
    assert len(current_var) > 0, "no training variables"

detection_training_graph = tf.Graph()
with detection_training_graph.as_default():
    tf.graph_util.import_graph_def(training_graph_def, name='')
    for var in current_var:
        # var.name looks like '<base>:0'; [:-2] strips the ':0' suffix
        ret = variable_pb2.VariableDef()
        ret.variable_name = var.name
        ret.initial_value_name = var.name[:-2] + '/Initializer/initial_value:0'
        ret.initializer_name = var.name[:-2] + '/Assign'
        ret.snapshot_name = var.name[:-2] + '/Read/ReadVariableOp:0'
        ret.trainable = True
        ret.is_resource = True
        tf_var = tf.Variable(variable_def=ret, dtype=tf.float32)
        # Re-register in the collections the GraphDef import did not restore
        detection_training_graph.add_to_collections({'trainable_variables', 'variables'}, tf_var)
    current_var = tf.compat.v1.trainable_variables()
    assert len(current_var) > 0, "no training variables"
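The string arithmetic in Sol2 relies on the op names that the code above assumes for a resource variable: '<base>/Initializer/initial_value', '<base>/Assign' and '<base>/Read/ReadVariableOp'. A small hypothetical helper, not part of the answer, makes that mapping explicit so it can be checked without TensorFlow. It uses rsplit(':', 1) instead of var.name[:-2], which only strips the suffix correctly when it is exactly two characters such as ':0'.

```python
from typing import Dict


def variable_def_names(var_name: str) -> Dict[str, str]:
    """Names Sol2 writes into a VariableDef, derived from a variable's tensor name.

    Hypothetical helper; the suffixes are copied from the Sol2 snippet.
    """
    # 'conv1/weights_turned_var:0' -> 'conv1/weights_turned_var'
    base = var_name.rsplit(":", 1)[0]
    return {
        "variable_name": var_name,
        "initial_value_name": f"{base}/Initializer/initial_value:0",
        "initializer_name": f"{base}/Assign",
        "snapshot_name": f"{base}/Read/ReadVariableOp:0",
    }


if __name__ == "__main__":
    for key, value in variable_def_names("conv1/weights_turned_var:0").items():
        print(f"{key}: {value}")
```

Each value would then be assigned to the field of the same name on variable_pb2.VariableDef, as in the loop above.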