TensorFlow v2:替换 tf.contrib.predictor.from_saved_model [英] TensorFlow v2: Replacement for tf.contrib.predictor.from_saved_model

查看:57
本文介绍了TensorFlow v2:替换 tf.contrib.predictor.from_saved_model的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

到目前为止,我使用 tf.contrib.predictor.from_saved_model 来加载 SavedModel(tf.estimator 模型类).然而,不幸的是,这个功能在 TensorFlow v2 中被删除了.到目前为止,在 TensorFlow v1 中,我的编码如下:

So far, I was using tf.contrib.predictor.from_saved_model to load a SavedModel (tf.estimator model class). However, this function has unfortunately been removed in TensorFlow v2. So far, in TensorFlow v1, my coding was the following:

 predict_fn = predictor.from_saved_model(model_dir + '/' + model, signature_def_key='predict')

 prediction_feed_dict = dict()

 for key in predict_fn._feed_tensors.keys():

     #forec_data is a DataFrame holding the data to be fed in 
     for index in forec_data.index:
         prediction_feed_dict[key] = [ [ forec_data.loc[index][key] ] ]

 prediction_complete = predict_fn(prediction_feed_dict)

使用 tf.saved_model.load,我在 TensorFlow v2 中尝试了以下失败:

Using tf.saved_model.load, I unsuccessfully tried the following in TensorFlow v2:

 model = tf.saved_model.load(model_dir + '/' + latest_model)
 model_fn = model.signatures['predict']

 prediction_feed_dict = dict()

 for key in model_fn._feed_tensors.keys(): #<-- no replacement for _feed_tensors.keys() found

     #forec_data is a DataFrame holding the data to be fed in 
     for index in forec_data.index:
         prediction_feed_dict[key] = [ [ forec_data.loc[index][key] ] ]

 prediction_complete = model_fn(prediction_feed_dict) #<-- no idea if this is anyhow close to correct

所以我的问题是(都在 TensorFlow v2 的上下文中):

So my questions are (both in the context of TensorFlow v2):

  1. 如何替换 _feed_tensors.keys()?
  2. 如何使用加载了 tf.saved_model.loadtf.estimator 模型以直接的方式进行推理
  1. How can I replace _feed_tensors.keys()?
  2. How to inference in a straightforward way using a tf.estimator model loaded with tf.saved_model.load

非常感谢,感谢您的帮助.

Thanks a lot, any help is appreciated.

注意:此问题与发布的问题不重复 此处,因为那里提供的答案都依赖于已在 TensorFlow v2 中删除的 TensorFlow v1 的功能.

Note: This question is not a duplicate of the one posted here as the answers provided there all rely on features of TensorFlow v1 that have been removed in TensorFlow v2.

问题 postet 这里好像问的基本一样,但是直到现在(2020-01-22)也没有人回答.

The question postet here seems to ask basically the same thing, but until now (2020-01-22) is also unanswered.

推荐答案

希望您使用类似于下面提到的代码保存了 Estimator 模型:

Hope you have Saved the Estimator Model using the code similar to that mentioned below:

input_column = tf.feature_column.numeric_column("x")
estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])

def input_fn():
  return tf.data.Dataset.from_tensor_slices(
    ({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)
estimator.train(input_fn)

serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
  tf.feature_column.make_parse_example_spec([input_column]))
export_path = estimator.export_saved_model(
  "/tmp/from_estimator/", serving_input_fn)

您可以使用下面提到的代码加载模型:

You can Load the Model using the code mentioned below:

imported = tf.saved_model.load(export_path)

要通过传递输入特征来预测使用您的模型,您可以使用以下代码:

To Predict using your Model by passing the Input Features, you can use the below code:

def predict(x):
  example = tf.train.Example()
  example.features.feature["x"].float_list.value.extend([x])
  return imported.signatures["predict"](examples=tf.constant([example.SerializeToString()]))

print(predict(1.5))
print(predict(3.5))

更多详情,请参考这个链接,其中使用 TF Estimator 保存模型进行了解释.

For more details, please refer this link in which Saved Models using TF Estimator are explained.

这篇关于TensorFlow v2:替换 tf.contrib.predictor.from_saved_model的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆