Tensorflow LinearRegressor 功能不能有等级 0 [英] Tensorflow LinearRegressor Feature Cannot have rank 0

查看:29
本文介绍了Tensorflow LinearRegressor 功能不能有等级 0的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在学习教程,但未能为基于 y=x 生成的数据集构建线性回归器.这是我的代码的最后一部分,你可以找到完整的源代码 如果你想重现我的错误:

I am following the tutorial but failed to build a linear regressor for a dataset generated on top of y=x. Here is the last part of my code, and you can find the complete source code here if you want to reproduce my error:

_CSV_COLUMN_DEFAULTS = [[0],[0]]
_CSV_COLUMNS = ['x', 'y']

def input_fn(data_file):

    def parse_csv(value):
        print('Parsing', data_file)
        columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
        features = dict(zip(_CSV_COLUMNS, columns))
        labels = features.pop('y')
        return features, labels

    # Extract lines from input files using the Dataset API.
    dataset = tf.data.TextLineDataset(data_file)
    dataset = dataset.map(parse_csv)

    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    return features, labels

x = tf.feature_column.numeric_column('x')
base_columns = [x]

model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearRegressor(model_dir=model_dir,     feature_columns=base_columns)

model = model.train(input_fn=lambda: input_fn(data_file=file_path))

不知何故此代码将失败并显示错误消息

Somehow this code will fail with error message

ValueError: Feature (key: x) cannot have rank 0. Give: Tensor("IteratorGetNext:0", shape=(), dtype=int32, device=/device:CPU:0)

由于 tensorflow 的性质,我发现根据给定的消息检查它真正出错的地方有点困难.

Due to the nature of tensorflow, I found it a bit hard to inspect where it really went wrong based on the given message.

推荐答案

据我所知,值的第一个维度是 batch_size.所以当input_fn返回数据时,应该批量返回数据.

As far as I can tell, the first dimension of the values is meant to be the batch_size. So when input_fn returns the data, it should return data as a batch.

一旦您将数据作为批处理返回,它就会起作用,例如:

It works once you return the data as a batch, e.g.:

dataset = tf.data.TextLineDataset(data_file)
dataset = dataset.map(parse_csv)
dataset = dataset.batch(10) # or any other batch size

这篇关于Tensorflow LinearRegressor 功能不能有等级 0的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆