如何在 tensorflow 上加载和使用保存的模型? [英] how to load and use a saved model on tensorflow?

查看:34
本文介绍了如何在 tensorflow 上加载和使用保存的模型?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我发现了两种在 Tensorflow 中保存模型的方法:tf.train.Saver()SavedModelBuilder.但是,在以第二种方式加载后,我找不到有关使用该模型的文档.

I have found 2 ways to save a model in Tensorflow: tf.train.Saver() and SavedModelBuilder. However, I can't find documentation on using the model after it being loaded the second way.

注意:我想使用 SavedModelBuilder 方式,因为我用 Python 训练模型,并将在服务时以另一种语言(Go)使用它,而且似乎 SavedModelBuilder 在这种情况下是唯一的方法.

Note: I want to use SavedModelBuilder way because I train the model in Python and will use it at serving time in another language (Go), and it seems that SavedModelBuilder is the only way in that case.

这对 tf.train.Saver() 很有效(第一种方式):

This works great with tf.train.Saver() (first way):

model = tf.add(W * x, b, name="finalnode")

# save
saver = tf.train.Saver()
saver.save(sess, "/tmp/model")

# load
saver.restore(sess, "/tmp/model")

# IMPORTANT PART: REALLY USING THE MODEL AFTER LOADING IT
# I CAN'T FIND AN EQUIVALENT OF THIS PART IN THE OTHER WAY.

model = graph.get_tensor_by_name("finalnode:0")
sess.run(model, {x: [5, 6, 7]})

tf.saved_model.builder.SavedModelBuilder() 定义在 Readme 但在使用 tf.saved_model.loader.load(sess, [], export_dir)) 加载模型后,我找不到有关获取的文档回到节点(参见上面代码中的finalnode")

tf.saved_model.builder.SavedModelBuilder() is defined in the Readme but after loading the model with tf.saved_model.loader.load(sess, [], export_dir)), I can't find documentation on getting back at the nodes (see "finalnode" in the code above)

推荐答案

缺少的是签名

# Saving
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
builder.add_meta_graph_and_variables(sess, ["tag"], signature_def_map= {
        "model": tf.saved_model.signature_def_utils.predict_signature_def(
            inputs= {"x": x},
            outputs= {"finalnode": model})
        })
builder.save()

# loading
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, ["tag"], export_dir)
    graph = tf.get_default_graph()
    x = graph.get_tensor_by_name("x:0")
    model = graph.get_tensor_by_name("finalnode:0")
    print(sess.run(model, {x: [5, 6, 7, 8]}))

这篇关于如何在 tensorflow 上加载和使用保存的模型?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆