Re-train a frozen *.pb model in TensorFlow


Question

How do I import a frozen protobuf to enable it for re-training?

All the methods I've found online expect checkpoints. Is there a way to read a protobuf such that kernel and bias constants are converted to variables?

Edit 1: This is similar to the following question: How to retrain model in graph (.pb)?

I looked at DeepSpeech, which was recommended in the answers to that question. They seem to have removed support for initialize_from_frozen_model. I couldn't find the reason.

Edit 2: I tried creating a new GraphDef object where I replace the kernels and biases with Variables:

probable_variables = [...]  # names of the Conv2D and MatMul kernels and biases

new_graph_def = tf.GraphDef()

with tf.Session(graph=graph) as sess:
    for n in sess.graph_def.node:
        if n.name in probable_variables:
            # create a variable op in place of the constant
            nn = new_graph_def.node.add()
            nn.name = n.name
            nn.op = 'VariableV2'
            nn.attr['dtype'].CopyFrom(attr_value_pb2.AttrValue(type=dtype))
            nn.attr['shape'].CopyFrom(attr_value_pb2.AttrValue(shape=shape))
        else:
            # copy every other node unchanged
            nn = new_graph_def.node.add()
            nn.CopyFrom(n)

Not sure if I am on the right path. I don't know how to set trainable=True in a NodeDef object.
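The `dtype` and `shape` in the loop above are left undefined; one way to obtain them (a sketch, not the asker's actual code, written against `tf.compat.v1` so it also runs under TF 2.x, with a toy constant standing in for a frozen kernel) is to read them straight off the Const node being replaced:

```python
import numpy as np
import tensorflow as tf
from tensorflow.core.framework import attr_value_pb2

tf1 = tf.compat.v1

# Build a tiny graph with one constant standing in for a frozen kernel.
with tf1.Graph().as_default() as g:
    tf1.constant(np.ones((1, 2), np.float32), name='conv/kernel')
const_node = next(n for n in g.as_graph_def().node if n.name == 'conv/kernel')

# A Const node carries its dtype in attr['dtype'] and its shape inside
# the serialized tensor stored in attr['value']:
dtype = const_node.attr['dtype'].type
shape = const_node.attr['value'].tensor.tensor_shape

# These can then fill the VariableV2 node exactly as in the snippet above.
new_graph_def = tf1.GraphDef()
nn = new_graph_def.node.add()
nn.name = const_node.name
nn.op = 'VariableV2'
nn.attr['dtype'].CopyFrom(attr_value_pb2.AttrValue(type=dtype))
nn.attr['shape'].CopyFrom(attr_value_pb2.AttrValue(shape=shape))
```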

Answer

The code snippet you provided is actually going in the right direction :)

The trickiest part is to get the names of the previously trainable variables. Hopefully the model was created with a high-level framework, like keras or tf.slim - they wrap their variables nicely in something like conv2d_1/kernel, dense_1/bias, batch_normalization/gamma, etc.

If you're not sure, the most useful thing to do is to visualize the graph...

# read graph definition
with tf.gfile.GFile('frozen.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# now build the graph in the memory and visualize it
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name="prefix")
    writer = tf.summary.FileWriter('out', graph)
    writer.close()

... with TensorBoard:

$ tensorboard --logdir out/

and see for yourself what the graph looks like and what the naming is.
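If spinning up TensorBoard is overkill, simply dumping node names and op types from the GraphDef also reveals the naming. A sketch with a toy graph standing in for the frozen one (`input`, `dense/kernel` etc. are made-up names):

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1

# Toy graph standing in for the imported frozen graph.
with tf1.Graph().as_default() as g:
    x = tf1.placeholder(tf.float32, [None, 2], name='input')
    w = tf1.constant(np.ones((2, 3), np.float32), name='dense/kernel')
    y = tf1.matmul(x, w, name='dense/MatMul')

# Dump every node: the Const nodes feeding MatMul/Conv2D ops
# are the candidates for conversion back into variables.
for n in g.as_graph_def().node:
    print('{:16s} {}'.format(n.op, n.name))

const_names = [n.name for n in g.as_graph_def().node if n.op == 'Const']
```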

All you need is the magical library called tf.contrib.graph_editor. Now let's say you've stored the names of the previously trainable ops (the ones that used to be variables but are now Const) in probable_variables (as in your Edit 2).

Note: remember the difference between ops, tensors, and variables. Ops are elements of the graph, a tensor is a buffer that contains the results of ops, and a variable is a wrapper around a tensor, with three ops: assign (called when you initialize the variable), read (called by other ops, e.g. conv2d), and a ref tensor (which holds the values).
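Those three pieces can be seen concretely in a small sketch (passing use_resource=False so that compat.v1 creates a classic VariableV2, matching the TF 1.x frozen-graph world): creating one variable adds the VariableV2 op itself, an Assign op for initialization, and a /read identity to the graph:

```python
import tensorflow as tf

tf1 = tf.compat.v1

with tf1.Graph().as_default() as g:
    # use_resource=False gives the classic VariableV2, as in frozen TF 1.x graphs
    v = tf1.get_variable('conv/kernel', shape=[3], dtype=tf.float32,
                         use_resource=False)

ops = {op.name: op.type for op in g.get_operations()}
# The variable op itself, its initializer's assign, and the read op:
print(ops['conv/kernel'])         # VariableV2
print(ops['conv/kernel/Assign'])  # Assign
print(ops['conv/kernel/read'])    # Identity
```

The `'{}/read'` naming is what the graph_editor snippet below relies on when looking up `var_name + '/read'`.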

Note 2: graph_editor can only be run outside a session – you cannot make any graph modification online!

import numpy as np
import tensorflow as tf
import tensorflow.contrib.graph_editor as ge

# load the graph definition into memory, just as in Step 1
graph = load_graph('frozen.pb')

# create a variable for each constant, beware the naming;
# the variables must live in the same graph, hence graph.as_default()
const_var_pairs = []
with graph.as_default():
    for name in probable_variables:
        var_shape = graph.get_tensor_by_name('{}:0'.format(name)).get_shape()
        var_name = '{}_a'.format(name)
        var = tf.get_variable(name=var_name, shape=var_shape, dtype='float32')
        const_var_pairs.append((name, var))

# from now on we work with the graph's ops
name_to_op = dict([(op.name, op) for op in graph.get_operations()])

# magic: now we swap the outputs of the const and the variable's read op
for const_name, var in const_var_pairs:
    const_op = name_to_op[const_name]
    var_reader_op = name_to_op[var.op.name + '/read']
    ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))

# now we can safely create a session and copy the values
sess = tf.Session(graph=graph)
for const_name, var in const_var_pairs:
    ts = graph.get_tensor_by_name('{}:0'.format(const_name))
    var.load(ts.eval(session=sess), session=sess)

# All done! Now you can make sure everything is correct by visualizing
# the graph and computing outputs for some inputs.

PS: this code was not tested; however, I've been using graph_editor and performing network surgery quite often lately, so I think it should mostly be correct :)

