检查 tensorflow keras 模型中的下一层 [英] Check which are the next layers in a tensorflow keras model
问题描述
我有一个 keras 模型,它有快捷方式层之间.对于每一层,我想获得下一个连接层的名称(或索引),因为简单地遍历所有 model.layers
不会告诉我该层是否连接到前一层与否.
I have a keras model which has shortcuts between layers. For each layer, I would like to get the name (or index) of the next connected layers, because simply iterating through all the model.layers
will not tell me whether the layer was connected to the previous one or not.
示例模型可以是:
model = tf.keras.applications.resnet50.ResNet50(
include_top=True, weights='imagenet', input_tensor=None,
input_shape=None, pooling=None, classes=1000)
推荐答案
你可以这样提取dict
格式的信息...
You can extract the information in dict
format in this way...
首先,定义一个效用函数并从每个 Functional
模型(代码参考>
Firstly, define a utility function and get the relevant nodes as made in the model.summary()
method from every Functional
model (code reference)
relevant_nodes = []
for v in model._nodes_by_depth.values():
relevant_nodes += v
def get_layer_summary_with_connections(layer):
info = {}
connections = []
for node in layer._inbound_nodes:
if relevant_nodes and node not in relevant_nodes:
# node is not part of the current network
continue
for inbound_layer, node_index, tensor_index, _ in node.iterate_inbound():
connections.append(inbound_layer.name)
name = layer.name
info['type'] = layer.__class__.__name__
info['parents'] = connections
return info
其次,通过层层迭代提取信息:
Secondly, extract the information iterating through layers:
results = {}
layers = model.layers
for layer in layers:
info = get_layer_summary_with_connections(layer)
results[layer.name] = info
results
是具有以下格式的嵌套 dict
:
results
is a nested dict
with this format:
{
'layer_name': {'type':'the layer type', 'parents':'list of the parent layers'},
...
'layer_name': {'type':'the layer type', 'parents':'list of the parent layers'}
}
对于 ResNet50
它导致:
{
'input_4': {'type': 'InputLayer', 'parents': []},
'conv1_pad': {'type': 'ZeroPadding2D', 'parents': ['input_4']},
'conv1_conv': {'type': 'Conv2D', 'parents': ['conv1_pad']},
'conv1_bn': {'type': 'BatchNormalization', 'parents': ['conv1_conv']},
...
'conv5_block3_out': {'type': 'Activation', 'parents': ['conv5_block3_add']},
'avg_pool': {'type': 'GlobalAveragePooling2D', 'parents' ['conv5_block3_out']},
'predictions': {'type': 'Dense', 'parents': ['avg_pool']}
}
另外,你可以修改get_layer_summary_with_connections
返回你感兴趣的所有信息
Also, you can modify get_layer_summary_with_connections
to return all the information you are interested in
这篇关于检查 tensorflow keras 模型中的下一层的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!