Is there a general way to set the record_defaults for tf.decode_csv?

Question

I am fairly new to TensorFlow and have run into a few difficulties. After preprocessing my data I saved it as a CSV, because I didn't know how to generate batches. I then tried to read it in TensorFlow, but tf.decode_csv requires the record_defaults argument. My data has so many columns that writing out record_defaults by hand takes a long time. How can I set every default to 0, assuming the number of columns is not known in advance?

Recommended answer

I tried to find a way to skip or work around this, but there seems to be no way to omit record_defaults.

But since you have to know the length of a CSV row anyway in order to read it in TensorFlow, the easiest workaround I have found so far is to pre-fill the defaults with this one-liner:

rDefaults = [['a'] for _ in range(num_cells_in_your_row)]

My data has about 800 columns per row, and this way I don't have to specify each one individually. In my case the data needs to be read in as strings, but you can use 0.0 (or another value) instead of 'a' if you need floats.
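
If the column count is not known in advance, one option is to read it from the first line of the file before building the graph. This is a minimal sketch, assuming a plain comma-separated file whose first row has the same number of fields as every other row. `count_csv_columns`, `build_record_defaults`, and the sample file are hypothetical names used only for illustration:

```python
import csv
import os
import tempfile


def count_csv_columns(csv_path):
    """Return the number of fields in the first row of a CSV file."""
    with open(csv_path, newline="") as f:
        return len(next(csv.reader(f)))


def build_record_defaults(num_columns, default=0.0):
    """Build one single-element list per column (one default per CSV column)."""
    return [[default] for _ in range(num_columns)]


# Demo on a small throwaway file (made-up data, for illustration only).
with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as tmp:
    tmp.write("1.0,2.0,3.0,4.0\n5.0,6.0,7.0,8.0\n")
    sample_path = tmp.name

num_columns = count_csv_columns(sample_path)
record_defaults = build_record_defaults(num_columns)
os.remove(sample_path)

print(num_columns)
print(record_defaults)
```

The resulting list can then be passed as `record_defaults`. If the file has a header row, counting its fields works the same way.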

*** Update:

As discussed, the approach above does not restrict you to data whose type is uniform across the whole row. Here is how to convert specific cells within the row to the data type you need:

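import tensorflow as tf  # TF 1.x queue-based API

# One single-element default per column; 'a' makes every column a string.
rDefaults = [['a'] for _ in range(num_cells_in_your_row)]

def read_from_csv(filename_queue):
    # CD, CC and SEQLEN are the column counts of the three row segments,
    # assumed to be defined elsewhere.
    reader = tf.TextLineReader(skip_header_lines=0)  # number of header lines to skip
    _, csv_row = reader.read(filename_queue)
    data = tf.decode_csv(csv_row, record_defaults=rDefaults)
    dateLbl = tf.slice(data, [0], [CD])    # portion of the row that is 'String'
    crossLbl = tf.slice(data, [CD], [CC])  # also 'String'
    # Convert the rest of the row from String to float:
    obs = tf.string_to_number(tf.slice(data, [CD + CC], [SEQLEN]), tf.float32)
    return dateLbl, crossLbl, obs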
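
Because `tf.decode_csv` takes each column's output type from the matching entry in `record_defaults`, another option is to give the defaults mixed types so that the numeric columns come back as floats directly. This is a sketch under assumptions: the values of `CD`, `CC`, and `SEQLEN` below are made up, and `make_mixed_defaults` is a hypothetical helper, not part of TensorFlow:

```python
def make_mixed_defaults(num_string_cols, num_float_cols):
    """String defaults for the leading label columns, float defaults for the rest."""
    string_defaults = [[""] for _ in range(num_string_cols)]
    float_defaults = [[0.0] for _ in range(num_float_cols)]
    return string_defaults + float_defaults


# Assumed layout: CD date-label columns, CC cross-label columns, SEQLEN observations.
CD, CC, SEQLEN = 1, 2, 5
mixed_defaults = make_mixed_defaults(CD + CC, SEQLEN)
print(mixed_defaults)
```

With defaults like these, the decoded columns no longer share one dtype, so it is probably simpler to slice the returned Python list (for example `data[:CD]`) than to call `tf.slice` on it.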

Hope this helps...
