无法使用BatchNorm层导入冻结图 [英] Can't import frozen graph with BatchNorm layer

查看:115
本文介绍了无法使用BatchNorm层导入冻结图的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我已经基于 repo 训练了Keras模型.

I have trained a Keras model based on this repo.

训练后,我将模型保存为如下检查点文件:

After the training I save the model as checkpoint files like this:

 sess=tf.keras.backend.get_session() 
 saver = tf.train.Saver()
 saver.save(sess, current_run_path + '/checkpoint_files/model_{}.ckpt'.format(date))

然后,我从检查点文件中还原图形,并使用标准tf Frozen_graph脚本将其冻结.当我想恢复冻结的图时,出现以下错误:

Then I restore the graph from the checkpoint files and freeze it using the standard tf freeze_graph script. When I want to restore the frozen graph I get the following error:

Input 0 of node Conv_BN_1/cond/ReadVariableOp/Switch was passed float from Conv_BN_1/gamma:0 incompatible with expected resource

如何解决此问题?

我的问题与这个问题.不幸的是,我无法使用替代方法.

My problem is related to this question. Unfortunately, I can't use the workaround.

我在github上打开了一个问题,并创建了一个要点来重现该错误. https://github.com/keras-team/keras/issues/11032

Edit 2: I have opened an issue on github and created a gist to reproduce the error. https://github.com/keras-team/keras/issues/11032

推荐答案

只需解决相同的问题.我连接了以下几个答案: 1 3 ,并意识到该问题源自 batchnorm层的工作状态:培训或学习.因此,为了解决该问题,您只需要在加载模型之前放置一行即可:

Just resolved the same issue. I connected this few answers: 1, 2, 3 and realized that issue originated from batchnorm layer working state: training or learning. So, in order to resolve that issue you just need to place one line before loading your model:

keras.backend.set_learning_phase(0)

完整示例,以导出模型

import tensorflow as tf
from tensorflow.python.framework import graph_io
from tensorflow.keras.applications.inception_v3 import InceptionV3


def freeze_graph(graph, session, output):
    with graph.as_default():
        graphdef_inf = tf.graph_util.remove_training_nodes(graph.as_graph_def())
        graphdef_frozen = tf.graph_util.convert_variables_to_constants(session, graphdef_inf, output)
        graph_io.write_graph(graphdef_frozen, ".", "frozen_model.pb", as_text=False)

tf.keras.backend.set_learning_phase(0) # this line most important

base_model = InceptionV3()

session = tf.keras.backend.get_session()

INPUT_NODE = base_model.inputs[0].op.name
OUTPUT_NODE = base_model.outputs[0].op.name
freeze_graph(session.graph, session, [out.op.name for out in base_model.outputs])

加载* .pb模型:

from PIL import Image
import numpy as np
import tensorflow as tf

# https://i.imgur.com/tvOB18o.jpg
im = Image.open("/home/chichivica/Pictures/eagle.jpg").resize((299, 299), Image.BICUBIC)
im = np.array(im) / 255.0
im = im[None, ...]

graph_def = tf.GraphDef()

with tf.gfile.GFile("frozen_model.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

graph = tf.Graph()

with graph.as_default():
    net_inp, net_out = tf.import_graph_def(
        graph_def, return_elements=["input_1", "predictions/Softmax"]
    )
    with tf.Session(graph=graph) as sess:
        out = sess.run(net_out.outputs[0], feed_dict={net_inp.outputs[0]: im})
        print(np.argmax(out))

这篇关于无法使用BatchNorm层导入冻结图的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆