无法获得 tensorflow DNNClassifier 的预测 [英] Cannot get predictions of tensorflow DNNClassifier

查看:32
本文介绍了无法获得 tensorflow DNNClassifier 的预测的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在使用 MNIST 教程中的代码:

I'm using the code from the MNIST tutorial:

feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                            hidden_units=[10, 20, 10],
                                            n_classes=2,
                                            model_dir="/tmp/iris_model")

classifier.fit(x=np.array(train, dtype = 'float32'),
               y=np.array(y_tr, dtype = 'int64'),
               steps=2000)

accuracy_score = classifier.evaluate(x=np.array(test, dtype = 'float32'),
                                     y=y_test)["auc"]
print('AUC: {0:f}'.format(accuracy_score))

from tensorflow.contrib.learn import SKCompat
ds_test_ar = np.array(ds_test, dtype = 'float32')

ds_predict_tf = classifier.predict(input_fn = _my_predict_data)
print('Predictions: {}'.format(str(ds_predict_tf)))

但最后我得到了以下结果而不是预测:

but at the end I got the following result instead of the predictions:

Predictions: <generator object DNNClassifier.predict.<locals>.<genexpr> at 0x000002CE41101CA8>

我做错了什么?

推荐答案

您收到并保存到 ds_predict_tf 的是一个生成器表达式.要打印它,您可以执行以下操作:

What you received and saved to ds_predict_tf is a generator expression. To print it you can do:

for i in ds_predict_tf:
    print i

print(list(ds_predict_tf))

您可以在此处阅读有关geneexpr的更多信息.

You can read more about genexpr here.

这篇关于无法获得 tensorflow DNNClassifier 的预测的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆