使用输入fn在Tensorflow估计器中进行预测 [英] Predict in Tensorflow estimator using input fn

查看:202
本文介绍了使用输入fn在Tensorflow估计器中进行预测的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我使用来自 https:中的教程代码: //github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/examples/learn/wide_n_deep_tutorial.py 并且代码可以正常工作,直到我尝试做出预测而不是仅仅对其进行评估。我试图制作另一个看起来像这样的预测功能(只需删除参数y):

I use the tutorial code from https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/examples/learn/wide_n_deep_tutorial.py and the code works fine until I tried to make a prediction instead of just evaluate it. I tried to make another function for prediction that look like this (by just removing parameter y):

def input_fn_predict(data_file, num_epochs, shuffle):
  """Input builder function."""
  df_data = pd.read_csv(
      tf.gfile.Open(data_file),
      names=CSV_COLUMNS,
      skipinitialspace=True,
      engine="python",
      skiprows=1)
  # remove NaN elements
  df_data = df_data.dropna(how="any", axis=0)
  labels = df_data["income_bracket"].apply(lambda x: ">50K" in x).astype(int)
  return tf.estimator.inputs.pandas_input_fn( #removed paramter y
      x=df_data,
      batch_size=100,
      num_epochs=num_epochs,
      shuffle=shuffle,
      num_threads=5)

这样称呼它:

predictions = m.predict(
      input_fn=input_fn_predict(test_file_name, num_epochs=1, shuffle=True)
  )
  for i, p in enumerate(predictions):
      print(i, p)




  • 我做对了吗?

  • 为什么我得到预测81404而不是16282(测试文件中的行数)?

  • 每行包含以下内容:


  • {'概率':array([0.78595656,0.21404342],dtype = float32),
    'logits':array([-1.3007226],dtype = float32),'classes':array(['0'],
    dtype = object),'class_ids':array( [0]),后勤:array([
    0.21404341],dtype = float32)}

    {'probabilities': array([ 0.78595656, 0.21404342], dtype=float32), 'logits': array([-1.3007226], dtype=float32), 'classes': array(['0'], dtype=object), 'class_ids': array([0]), 'logistic': array([ 0.21404341], dtype=float32)}

    我该如何

    推荐答案

    您需要设置 shuffle = False ,因为预测新标签,您需要维护数据顺序。

    You need to set shuffle=False since to predict new label, you need to maintain data order.

    下面是我运行预测的代码(我已经测试过了)。输入文件就像测试数据(在csv中一样),但是没有标签列。

    Below is my code to run the prediction (I've tested it). The input file is like test data (in csv), but there is no label column.

    
    
        def predict_input_fn(data_file):
            global CSV_COLUMNS
            CSV_COLUMNS = CSV_COLUMNS[:-1]
            df_data = pd.read_csv(
                tf.gfile.Open(data_file),
                names=CSV_COLUMNS,
                skipinitialspace=True,
                engine='python',
                skiprows=1
            )
    
            # remove NaN elements
            df_data = df_data.dropna(how='any', axis=0)
    
            return tf.estimator.inputs.pandas_input_fn(
                x=df_data,
                num_epochs=1,
               shuffle=False
            )
    
    

    调用它:

    
    
        predict_file_name = 'tutorials/data/adult.predict'
        results = m.predict(
            input_fn=predict_input_fn(predict_file_name)
        )
        for result in results:
            print 'result: {}'.format(result)
    
    

    预测一个样本的结果如下:

    The prediction result for one sample is below:

    
    
        {
            'probabilities': array([0.78595656, 0.21404342], dtype = float32),
            'logits': array([-1.3007226], dtype = float32),
            'classes': array(['0'], dtype = object),
            'class_ids': array([0]),
            'logistic': array([0.21404341], dtype = float32)
        }
    
    

    每个字段表示什么


    • '概率':array([0.78595656,0.21404342],dtype = float32)
      预测
      的输出标签为class-0(在这种情况下< = 50K)置信度0.78595656

    • 'logits':array([-1.3007226],dtype = float32)

      等式1 /中z的值(1 + e ^(-z))是-1.3。

    • 'classes':array(['0'],dtype = object)

      类标签是0

    • 'probabilities': array([0.78595656, 0.21404342], dtype = float32).
      It predicts the output label is class-0 (in this case <=50K) with confidence 0.78595656
    • 'logits': array([-1.3007226], dtype = float32)
      The value of z in equation 1/(1+e^(-z)) is -1.3.
    • 'classes': array(['0'], dtype = object)
      The class label is 0

    这篇关于使用输入fn在Tensorflow估计器中进行预测的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆