saved_model.prune() in TF2.0


Problem description

I am trying to prune nodes of a SavedModel that was generated with tf.keras. The pruning script is as follows:

svmod = tf.saved_model.load(fn) #version 1
#svmod = tfk.experimental.load_from_saved_model(fn) #version 2
feeds = ['foo:0']
fetches = ['bar:0']
svmod2 = svmod.prune(feeds=feeds, fetches=fetches)
tf.saved_model.save(svmod2, '/tmp/saved_model/') #version 1
#tfk.experimental.export_saved_model(svmod2, '/tmp/saved_model/') #version 2

If I use version #1, pruning works but saving gives ValueError: Expected a Trackable object for export. In version #2, there is no prune() method.

How can I prune a TF2.0 Keras SavedModel?

Answer

It looks like the way you are pruning the model in version 1 is fine; according to your error message, the resulting pruned model cannot be saved because it is not "trackable", which is a necessary condition for saving a model with tf.saved_model.save. One way to make a trackable object is to inherit from the tf.Module class, as described in the guides on using the SavedModel format and concrete functions. Below is an example that tries to save a tf.function object (which fails because the object is not trackable), then inherits from tf.Module, and saves the resulting object:

(Using Python version 3.7.6, TensorFlow version 2.1.0, and NumPy version 1.18.1)

import tensorflow as tf, numpy as np

# Define a random TensorFlow function and generate a reference output
conv_filter = tf.random.normal([1, 2, 4, 2], seed=1254)
@tf.function
def conv_model(x):
    return tf.nn.conv2d(x, conv_filter, 1, "SAME")

input_tensor = tf.ones([1, 2, 3, 4])
output_tensor = conv_model(input_tensor)
print("Original model outputs:", output_tensor, sep="\n")

# Try saving the model: it won't work because a tf.function is not trackable
export_dir = "./tmp/"
try: tf.saved_model.save(conv_model, export_dir)
except ValueError: print(
    "Can't save {} object because it's not trackable".format(type(conv_model)))

# Now define a trackable object by inheriting from the tf.Module class
class MyModule(tf.Module):
    @tf.function
    def __call__(self, x): return conv_model(x)

# Instantiate the trackable object, and call once to trace-compile a graph
module_func = MyModule()
module_func(input_tensor)
tf.saved_model.save(module_func, export_dir)

# Restore the model and verify that the outputs are consistent
restored_model = tf.saved_model.load(export_dir)
restored_output_tensor = restored_model(input_tensor)
print("Restored model outputs:", restored_output_tensor, sep="\n")
if np.array_equal(output_tensor.numpy(), restored_output_tensor.numpy()):
    print("Outputs are consistent :)")
else: print("Outputs are NOT consistent :(")

Console output:

Original model outputs:
tf.Tensor(
[[[[-2.3629642   1.2904963 ]
   [-2.3629642   1.2904963 ]
   [-0.02110204  1.3400152 ]]

  [[-2.3629642   1.2904963 ]
   [-2.3629642   1.2904963 ]
   [-0.02110204  1.3400152 ]]]], shape=(1, 2, 3, 2), dtype=float32)
Can't save <class 'tensorflow.python.eager.def_function.Function'> object
because it's not trackable
Restored model outputs:
tf.Tensor(
[[[[-2.3629642   1.2904963 ]
   [-2.3629642   1.2904963 ]
   [-0.02110204  1.3400152 ]]

  [[-2.3629642   1.2904963 ]
   [-2.3629642   1.2904963 ]
   [-0.02110204  1.3400152 ]]]], shape=(1, 2, 3, 2), dtype=float32)
Outputs are consistent :)

Therefore you should try modifying your code as follows:

svmod = tf.saved_model.load(fn) #version 1
svmod2 = svmod.prune(feeds=['foo:0'], fetches=['bar:0'])

class Exportable(tf.Module):
    @tf.function
    def __call__(self, model_inputs): return svmod2(model_inputs)

svmod2_export = Exportable()
svmod2_export(typical_input)    # call once with typical input to trace-compile
tf.saved_model.save(svmod2_export, '/tmp/saved_model/')
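
To check that the pruned model round-trips correctly, you can reload it and call it on the same representative input, just like the consistency check in the example above (a minimal sketch; typical_input is whatever input you used for tracing):

# Reload the saved pruned module and run it on the representative input
restored_pruned = tf.saved_model.load('/tmp/saved_model/')
print(restored_pruned(typical_input))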

If you don't want to inherit from tf.Module, you can alternatively just instantiate a tf.Module object and add a tf.function method/callable attribute by replacing that section of code as follows:

to_export = tf.Module()
to_export.call = tf.function(conv_model)
to_export.call(input_tensor)
tf.saved_model.save(to_export, export_dir)

restored_module = tf.saved_model.load(export_dir)
restored_func = restored_module.call
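
As in the first example, you can then call the restored function and verify that it reproduces the original output (a short sketch reusing input_tensor and output_tensor from above):

# Run the restored callable and compare against the original output
restored_output = restored_func(input_tensor)
if np.array_equal(output_tensor.numpy(), restored_output.numpy()):
    print("Outputs are consistent :)")
else:
    print("Outputs are NOT consistent :(")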
