如何检查Tensorflow LinearClassifier的特征权重? [英] How to examine the feature weights of a Tensorflow LinearClassifier?

查看:234
本文介绍了如何检查Tensorflow LinearClassifier的特征权重?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在尝试理解 带有TensorFlow的大型线性模型 文档.这些文档激发了这些模型,如下所示:

I am trying to understand the Large-scale Linear Models with TensorFlow documentation. The docs motivate these models as follows:

线性模型比神经模型更容易解释和调试 网. 您可以检查分配给每个功能的权重以得出 找出对预测影响最大的是什么.

Linear model can be interpreted and debugged more easily than neural nets. You can examine the weights assigned to each feature to figure out what's having the biggest impact on a prediction.

因此,我从随附的> TensorFlow线性模型教程中运行了扩展代码示例.特别是,我运行了 示例GitHub中的代码 ,且model-type标志设置为wide.可以正确运行并生成accuracy: 0.833733,类似于Tensorflow网页上的accuracy: 0.83557522.

So I ran the extended code example from the accompanying TensorFlow Linear Model Tutorial. In particular, I ran the example code from GitHub with the model-type flag set to wide. This correctly ran and produced accuracy: 0.833733, similar to the accuracy: 0.83557522 on the Tensorflow web page.

该示例使用tf.estimator.LinearClassifier训练权重.但是,与引用的能够检查权重的动机相反,我在

The example uses a tf.estimator.LinearClassifier to train the weights. However, in contrast to the quoted motivation of being able to examine the weights, I can't find any function to actually extract the trained weights in the LinearClassifier documentation.

问题:如何访问tf.estimator.LinearClassifier中各个功能列的训练权重?我希望能够提取NumPy数组中的所有权重.

Question: how do I access the trained weights for the various feature columns in a tf.estimator.LinearClassifier? I'd prefer to be able to extract all the weights in a NumPy array.

注意:我来自一个R环境,其中线性回归/分类模型具有coefs方法来提取学习的权重.我希望能够在同一数据集上比较R和TensorFlow中的线性模型.

Note: I am coming from an R environment where linear regression / classification models have a coefs method to extract learned weights. I want to be able to compare linear models in both R and TensorFlow on the same datasets.

推荐答案

用Estimator训练模型后,您可以使用 tf.train.list_variables 查找模型权重的名称.

After training the model with Estimator, you could use the tf.train.load_variable to retrieve the weights from checkpoint. You can use tf.train.list_variables to find the names for model weights.

有计划也直接在Estimator中添加此支持.

There are plans to add this support in Estimator directly also.

这篇关于如何检查Tensorflow LinearClassifier的特征权重?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆