无法使用BatchNorm层导入冻结图 [英] Can't import frozen graph with BatchNorm layer
问题描述
我已经基于 repo 训练了Keras模型.
I have trained a Keras model based on this repo.
训练后,我将模型保存为如下检查点文件:
After the training I save the model as checkpoint files like this:
sess=tf.keras.backend.get_session()
saver = tf.train.Saver()
saver.save(sess, current_run_path + '/checkpoint_files/model_{}.ckpt'.format(date))
然后,我从检查点文件中还原图形,并使用标准tf Frozen_graph脚本将其冻结.当我想恢复冻结的图时,出现以下错误:
Then I restore the graph from the checkpoint files and freeze it using the standard tf freeze_graph script. When I want to restore the frozen graph I get the following error:
Input 0 of node Conv_BN_1/cond/ReadVariableOp/Switch was passed float from Conv_BN_1/gamma:0 incompatible with expected resource
如何解决此问题?
我的问题与这个问题.不幸的是,我无法使用替代方法.
My problem is related to this question. Unfortunately, I can't use the workaround.
我在github上打开了一个问题,并创建了一个要点来重现该错误. https://github.com/keras-team/keras/issues/11032
Edit 2: I have opened an issue on github and created a gist to reproduce the error. https://github.com/keras-team/keras/issues/11032
推荐答案
只需解决相同的问题.我连接了以下几个答案: 1 , 3 ,并意识到该问题源自 batchnorm层的工作状态:培训或学习.因此,为了解决该问题,您只需要在加载模型之前放置一行即可:
Just resolved the same issue. I connected this few answers: 1, 2, 3 and realized that issue originated from batchnorm layer working state: training or learning. So, in order to resolve that issue you just need to place one line before loading your model:
keras.backend.set_learning_phase(0)
完整示例,以导出模型
import tensorflow as tf
from tensorflow.python.framework import graph_io
from tensorflow.keras.applications.inception_v3 import InceptionV3
def freeze_graph(graph, session, output):
with graph.as_default():
graphdef_inf = tf.graph_util.remove_training_nodes(graph.as_graph_def())
graphdef_frozen = tf.graph_util.convert_variables_to_constants(session, graphdef_inf, output)
graph_io.write_graph(graphdef_frozen, ".", "frozen_model.pb", as_text=False)
tf.keras.backend.set_learning_phase(0) # this line most important
base_model = InceptionV3()
session = tf.keras.backend.get_session()
INPUT_NODE = base_model.inputs[0].op.name
OUTPUT_NODE = base_model.outputs[0].op.name
freeze_graph(session.graph, session, [out.op.name for out in base_model.outputs])
加载* .pb模型:
from PIL import Image
import numpy as np
import tensorflow as tf
# https://i.imgur.com/tvOB18o.jpg
im = Image.open("/home/chichivica/Pictures/eagle.jpg").resize((299, 299), Image.BICUBIC)
im = np.array(im) / 255.0
im = im[None, ...]
graph_def = tf.GraphDef()
with tf.gfile.GFile("frozen_model.pb", "rb") as f:
graph_def.ParseFromString(f.read())
graph = tf.Graph()
with graph.as_default():
net_inp, net_out = tf.import_graph_def(
graph_def, return_elements=["input_1", "predictions/Softmax"]
)
with tf.Session(graph=graph) as sess:
out = sess.run(net_out.outputs[0], feed_dict={net_inp.outputs[0]: im})
print(np.argmax(out))
这篇关于无法使用BatchNorm层导入冻结图的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!