Tensorflow:从图中删除节点 [英] Tensorflow: delete nodes from graph
本文介绍了Tensorflow:从图中删除节点的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!
问题描述
我正在尝试从图中删除一些节点并将其保存在 .pb 中
I'm trying to delete some nodes from graph and save it in .pb
只能将需要的节点添加到新的mod_graph_def
图中,但问题是图仍然有一些对其他节点输入中已删除节点的引用,但我无法修改节点的输入:>
Only needed nodes can be added to new mod_graph_def
graph, but the problem that graph still have some references to deleted node in other nodes inputs, but I can't modify inputs of node:
def delete_ops_from_graph():
with open(input_model_filepath, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
nodes = []
for node in graph_def.node:
if 'Neg' in node.name:
print('Drop', node.name)
else:
nodes.append(node)
mod_graph_def = tf.GraphDef()
mod_graph_def.node.extend(nodes)
# The problem that graph still have some references to deleted node in other nodes inputs
for node in mod_graph_def.node:
inp_names = []
for inp in node.input:
if 'Neg' in inp:
pass
else:
inp_names.append(inp)
node.input = inp_names # TypeError: Can't set composite field
with open(output_model_filepath, 'wb') as f:
f.write(mod_graph_def.SerializeToString())
推荐答案
def delete_ops_from_graph():
with open(input_model_filepath, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# Delete nodes
nodes = []
for node in graph_def.node:
if 'Neg' in node.name:
print('Drop', node.name)
else:
nodes.append(node)
mod_graph_def = tf.GraphDef()
mod_graph_def.node.extend(nodes)
# Delete references to deleted nodes
for node in mod_graph_def.node:
inp_names = []
for inp in node.input:
if 'Neg' in inp:
pass
else:
inp_names.append(inp)
del node.input[:]
node.input.extend(inp_names)
with open(output_model_filepath, 'wb') as f:
f.write(mod_graph_def.SerializeToString())
这篇关于Tensorflow:从图中删除节点的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!
查看全文