Tensorflow Estimator API:如何从输入函数传递参数 [英] Tensorflow Estimator API: How to pass parameter from input function

查看:24
本文介绍了Tensorflow Estimator API:如何从输入函数传递参数的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在尝试为我的模型添加类权重作为超参数,但是为了计算权重,我需要读取输入数据,这发生在 input_fn 内部,然后传递给 estimator.fit().input_fn 的输出只是特征,标签应该具有相同的形状 num_examples * num_features.我的问题 - 有没有办法将数据从 input_fn 传播到 model_fn 的超参数映射?或者作为替代 - 也许 input_fn 数据集有一个包装器,它允许对少数/欠采样多数以及批处理进行过采样 - 在这种情况下,我不需要任何参数来传播.

I'm trying to add class weights as a hyperparameter for my model, but to calculate weight I need to read input data, this happens inside input_fn which then passed to estimator.fit(). An output of input_fn are only features, labels which should have same shape num_examples * num_features. My questions - is there any way to propagate data from input_fn to model_fn's hyperparameter map? Or as alternative - maybe there is a wrapper for input_fn dataset which allows to oversample minority/undersample majority along with batching - in this case I would not need any parameter to propagate.

推荐答案

特征和标签都可以是张量字典(不仅仅是一个张量).张量可以是您想要的任何形状,但通常为 num_examples * ...

Both features and labels can be dictionary of tensors (not just one tensor). The tensors can be any shape you want though it's common to be num_examples * ...

如果您不使用任何预定义的估计器,最简单的方法是添加另一个具有计算权重所需的特征,计算模型中的权重,然后使用它们(乘以损失或将其作为一个参数).

If you don't use any of the predefined estimators, the easiest way would be to add another feature with what you need to compute the weights, compute the weights in the model then use them (multiply the loss or pass it as a parameter).

您还可以访问 input_fn 中的超参数,以便您可以在那里计算权重并将其添加为单独的列.

You also have access to hyper parameters inside the input_fn so you can compute the weight there and add it as a separate column.

如果您使用固定的估算器,请查看文档.我看到他们中的大多数都支持 weight_column_name.在这种情况下,只需将您在特征字典中用于权重值的名称命名为它即可.

If you use a canned estimator check the documentation. I see most of them support a weight_column_name. In this case just give it the name you used in the features dictionary for the weight values.

或者,如果所有其他方法都失败了,您可以在将数据提供给 tensorflow 之前以您想要的方式对数据进行采样.

Alternatively, if all else fails you can sample the data the way you want before you feed it to tensorflow.

这篇关于Tensorflow Estimator API:如何从输入函数传递参数的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
相关文章
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆