将元数据添加到 tensorflow 冻结图 pb [英] Add metadata to tensorflow frozen graph pb

查看:28
本文介绍了将元数据添加到 tensorflow 冻结图 pb的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

为了分享我们训练好的 tensorflow 网络,我们将图形冻结到一个 .pb 文件中.我们还创建了一个 xml 文件,其中包含一些元数据,例如输入张量和输出张量、要应用的预处理类型、训练数据信息等.然后通过加载图形和评估张量等使用 Java 或 C# 提供模型.

To share our trained tensorflow networks, we freeze the graph into a .pb file. We also create an xml file with some metadata such as the input tensors and output tensors, type of pre-processing to apply, training data information etc. The models are then served using Java or C# by loading the graph and evaluating the tensors etc.

为了使共享更容易,我想在 .pb 文件中的某处包含此 xml 数据.有没有办法做到这一点?一种想法是将其作为 tf.Constant,但我不知道如何将其连接到普通图.

To make sharing easier, I would like to include this xml data somewhere in the .pb file. Is there any way to do this? One idea would be to have it as a tf.Constant, but I don't see how I could connect it to the normal graph.

注意这是使用 freeze_graph.py.新的 SavedModel 格式是否更合适?

Note this is using freeze_graph.py. Is the new SavedModel format more suitable?

推荐答案

首先,是的,您应该使用新的 SavedModel 格式,因为它会得到 TF 团队的支持,并且也可以与 Keras 一起使用.您可以向模型添加一个额外的端点,它会返回一个带有 XML 数据字符串的常量张量(如您所述).

First of all, yes you should use the new SavedModel format, as it is what will be supported by the TF team going forwards, and works with Keras as well. You can add an additional endpoint to the model, that returns a constant tensor (as you mention) with a string of your XML data.

这很好,因为它是密封的——底层的保存模型格式并不重要,因为您的元数据保存在计算图本身中.

This is good because it's hermetic -- the underlying savemodel format does not matter, because your metadata is saved in the computation graph itself.

查看此问题的答案:保存 TF2 keras具有自定义签名 defs 的模型.对于 Keras,该答案并不能 100% 地为您提供帮助,因为它无法与 tf.keras.models.load 函数很好地互操作,因为它们将其包装在 tf.Module 中.幸运的是,如果添加 tf.function 装饰器,使用 tf.keras.Model 在 TF2 中也能正常工作:

See the answer to this question: Saving a TF2 keras model with custom signature defs . That answer doesn't get you 100% of the way there for Keras, because it doesn't interop nicely with the tf.keras.models.load function, as they wrap it inside a tf.Module. Luckily, using tf.keras.Model works as well in TF2, if you add a tf.function decorator:

class MyModel(tf.keras.Model):

  def __init__(self, metadata, **kwargs):
    super(MyModel, self).__init__(**kwargs)
    self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
    self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
    self.metadata = tf.constant(metadata)

  def call(self, inputs):
    x = self.dense1(inputs)
    return self.dense2(x)

  @tf.function(input_signature=[])
  def get_metadata(self):
    return self.metadata

model = MyModel('metadata_test')
input_arr = tf.random.uniform((5, 5, 1)) # This call is needed so Keras knows its input shape. You could define manually too
outputs = model(input_arr)

然后您可以按如下方式保存和加载您的模型:

Then you can save and load your model as follows:

tf.keras.models.save_model(model, 'test_model_keras')
model_loaded = tf.keras.models.load_model('test_model_keras')

最后使用 model_loaded.get_metadata() 来检索您的常量元数据张量.

And finally use model_loaded.get_metadata() to retrieve your constant metadata tensor.

这篇关于将元数据添加到 tensorflow 冻结图 pb的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆