使用 TF1 读取用 TF2 创建的 protobuf [英] reading a protobuf created with TF2 using TF1

查看:66
本文介绍了使用 TF1 读取用 TF2 创建的 protobuf的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有一个存储为 hdf5 的模型,我使用 saved_model.save 将其导出到 protobuf (PB) 文件,如下所示:

I have a model stored as an hdf5 which I export to a protobuf (PB) file using saved_model.save, like this:

from tensorflow import keras
import tensorflow as tf
model = keras.models.load_model("model.hdf5")
tf.saved_model.save(model, './output_dir/')

这很好用,结果是一个 saved_model.pb 文件,我以后可以用其他软件查看它,没有问题.

this works fine and the result is a saved_model.pb file which I can later view with other software with no issues.

但是,当我尝试使用 TensorFlow1 导入这个 PB 文件时,我的代码失败了.由于 PB 应该是一种通用格式,这让我很困惑.

However, when I try to import this PB file using TensorFlow1, my code fails. As PB is supposed to be a universal format, this confuses me.

我用来读取PB文件的代码是这样的:

The code I use to read the PB file is this:

import tensorflow as tf
curr_graph = tf.Graph()
curr_sess = tf.InteractiveSession(graph=curr_graph)
f = tf.gfile.GFile('model.hdf5','rb')
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
f.close()

这是我得到的例外:

回溯(最近一次调用最后一次):文件read_pb.py",第 14 行,在graph_def.ParseFromString(f.read()) google.protobuf.message.DecodeError:解析消息时出错

Traceback (most recent call last): File "read_pb.py", line 14, in graph_def.ParseFromString(f.read()) google.protobuf.message.DecodeError: Error parsing message

我有一个不同的模型存储为 PB 文件,读取代码可以正常工作.

I have a different model stored as a PB file on which the reading code works fine.

怎么回事?

***** 编辑 1 *****

***** EDIT 1 *****

在使用下面 Andrea Angeli 的代码时,我遇到了以下错误:

While using Andrea Angeli's code below, I've encountered the following error:

遇到错误:NodeDef 未提及 attr 'exponential_avg_factor'在 Op y:T、batch_mean:U、batch_variance:U、Reserve_space_1:U、reserve_space_2:U、reserve_space_3:U;attr=T:type,allowed=[DT_HALF, DT_BFLOAT16, DT_FLOAT];attr=U:type,allowed=[DT_FLOAT];attr=epsilon:float,default=0.0001;attr=data_format:string,default="NHWC",allowed=["NHWC", "NCHW"];attr=is_training:bool,default=true>;节点定义:{节点u-mobilenetv2/bn_Conv1/FusedBatchNormV3}.(检查您的GraphDef 解释二进制文件是最新的GraphDef 生成二进制文件.).

Encountered Error: NodeDef mentions attr 'exponential_avg_factor' not in Op y:T, batch_mean:U, batch_variance:U, reserve_space_1:U, reserve_space_2:U, reserve_space_3:U; attr=T:type,allowed=[DT_HALF, DT_BFLOAT16, DT_FLOAT]; attr=U:type,allowed=[DT_FLOAT]; attr=epsilon:float,default=0.0001; attr=data_format:string,default="NHWC",allowed=["NHWC", "NCHW"]; attr=is_training:bool,default=true>; NodeDef: {node u-mobilenetv2/bn_Conv1/FusedBatchNormV3}. (Check whether your GraphDef-interpreting binary is up to date with your GraphDef-generating binary.).

是否有解决方法?

推荐答案

您正在尝试读取 hdf5 文件,而不是使用 tf.saved_model.save(..) 保存的 protobuf 文件.另请注意,TF2 导出的 protobuf 与 TF 1 的冻结图不同,因为它仅包含计算图.

You are trying to read the hdf5 file and not the protobuf file you saved with tf.saved_model.save(..). Also beware, the TF2 exported protobuf is not the same as TF 1's frozen graph as it only contains the computation graph.

编辑 1:如果要从 TF 2 模型导出 TF 1 样式的冻结图,可以使用以下代码段完成:

Edit 1: If you want to export a TF 1 styled frozen graph from a TF 2 model, it can be done using the following code snippet:

from tensorflow.python.framework import convert_to_constants

def export_to_frozen_pb(model: tf.keras.models.Model, path: str) -> None:
    """
    Creates a frozen graph from a keras model.

    Turns the weights of a model into constants and saves the resulting graph into a protobuf file.

    Args:
        model: tf.keras.Model to convert into a frozen graph
        path: Path to save the profobuf file
    """
    inference_func = tf.function(lambda input: model(input))

    concrete_func = inference_func.get_concrete_function(tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))
    output_func = convert_to_constants.convert_variables_to_constants_v2(concrete_func)

    graph_def = output_func.graph.as_graph_def()
    graph_def.node[-1].name = 'output'

    with open(os.path.join(path, 'saved_model.pb'), 'wb') as freezed_pb:
        freezed_pb.write(graph_def.SerializeToString())

这将在您在 path 参数中指定的位置生成一个 protobuf 文件 (saved_model.pb).您的图形的输入节点将具有名称input:0"(这是由 lambda 实现的)和输出节点output:0".

This will result in a protobuf file (saved_model.pb) at the location you specify in path param. Your graph's input node will have the name "input:0" (this is achieved by the lambda) and the output node "output:0".

这篇关于使用 TF1 读取用 TF2 创建的 protobuf的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆