如何将Keras .h5导出到tensorflow .pb? [英] How to export Keras .h5 to tensorflow .pb?

查看:143
本文介绍了如何将Keras .h5导出到tensorflow .pb?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我已经用新的数据集对初始模型进行了微调,并将其保存为Keras中的".h5"模型.现在我的目标是在仅接受".pb"扩展名的android Tensorflow上运行我的模型.问题是在Keras或tensorflow中是否有任何库可以进行此转换?到目前为止,我已经看过这篇文章: https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html ,但目前还不清楚.

I have fine-tuned inception model with a new dataset and saved it as ".h5" model in Keras. now my goal is to run my model on android Tensorflow which accepts ".pb" extension only. question is that is there any library in Keras or tensorflow to do this conversion? I have seen this post so far : https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html but can't figure out yet.

推荐答案

Keras本身不包含任何将TensorFlow图导出为协议缓冲区文件的方法,但是您可以使用常规TensorFlow实用程序来实现. 此处是一篇博客文章,解释了如何使用实用程序脚本 freeze_graph.py ,这是完成操作的典型"方法.

Keras does not include by itself any means to export a TensorFlow graph as a protocol buffers file, but you can do it using regular TensorFlow utilities. Here is a blog post explaining how to do it using the utility script freeze_graph.py included in TensorFlow, which is the "typical" way it is done.

但是,我个人认为必须创建一个检查点,然后运行外部脚本来获取模型,但我更喜欢通过我自己的Python代码来执行此操作,所以我使用了这样的函数:

However, I personally find a nuisance having to make a checkpoint and then run an external script to obtain a model, and instead prefer to do it from my own Python code, so I use a function like this:

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.
    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = tf.graph_util.convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)
        return frozen_graph

freeze_graph.py的实现受到启发.参数也类似于脚本. session是TensorFlow会话对象.仅当您要保持某些变量不冻结时才需要keep_var_names(例如对于有状态模型),通常不需要. output_names是带有产生所需输出的操作名称的列表. clear_devices只是删除任何设备指令以使图形更易于移植.因此,对于具有一个输出的典型Keras model,您将执行以下操作:

Which is inspired in the implementation of freeze_graph.py. The parameters are similar to the script too. session is the TensorFlow session object. keep_var_names is only needed if you want to keep some variable not frozen (e.g. for stateful models), so generally not. output_names is a list with the names of the operations that produce the outputs that you want. clear_devices just removes any device directives to make the graph more portable. So, for a typical Keras model with one output, you would do something like:

from keras import backend as K

# Create, compile and train model...

frozen_graph = freeze_session(K.get_session(),
                              output_names=[out.op.name for out in model.outputs])

然后,您可以使用 tf.train.write_graph :

Then you can write the graph to a file as usual with tf.train.write_graph:

tf.train.write_graph(frozen_graph, "some_directory", "my_model.pb", as_text=False)

这篇关于如何将Keras .h5导出到tensorflow .pb?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆