When I use TensorFlow to decode a `csv` file, how can I apply 'tf.map_fn' to a SparseTensor?

Question

When I use the following code:

import tensorflow as tf

# def input_pipeline(filenames, batch_size):
#     # Define a `tf.contrib.data.Dataset` for iterating over one epoch of the data.
#     dataset = (tf.contrib.data.TextLineDataset(filenames)
#                .map(lambda line: tf.decode_csv(
#                     line, record_defaults=[['1'], ['1'], ['1']], field_delim='-'))
#                .shuffle(buffer_size=10)  # Equivalent to min_after_dequeue=10.
#                .batch(batch_size))

#     # Return an *initializable* iterator over the dataset, which will allow us to
#     # re-initialize it at the beginning of each epoch.
#     return dataset.make_initializable_iterator() 

def decode_func(line):
    record_defaults = [['1'],['1'],['1']]
    line = tf.decode_csv(line, record_defaults=record_defaults, field_delim='-')
    str_to_int = lambda r: tf.string_to_number(r, tf.int32)
    query = tf.string_split(line[:1], ",").values
    title = tf.string_split(line[1:2], ",").values
    query = tf.map_fn(str_to_int, query, dtype=tf.int32)
    title = tf.map_fn(str_to_int, title, dtype=tf.int32)
    label = line[2]
    return query, title, label

def input_pipeline(filenames, batch_size):
    # Define a `tf.contrib.data.Dataset` for iterating over one epoch of the data.
    dataset = tf.contrib.data.TextLineDataset(filenames)
    dataset = dataset.map(decode_func)
    dataset = dataset.shuffle(buffer_size=10)  # Equivalent to min_after_dequeue=10.
    dataset = dataset.batch(batch_size)

    # Return an *initializable* iterator over the dataset, which will allow us to
    # re-initialize it at the beginning of each epoch.
    return dataset.make_initializable_iterator() 


filenames=['2.txt']
batch_size = 3
num_epochs = 10
iterator = input_pipeline(filenames, batch_size)

# `a1`, `a2`, and `a3` represent the next element to be retrieved from the iterator.    
a1, a2, a3 = iterator.get_next()

with tf.Session() as sess:
    for _ in range(num_epochs):
        print(_)
        # Resets the iterator at the beginning of an epoch.
        sess.run(iterator.initializer)
        try:
            while True:
                a, b, c = sess.run([a1, a2, a3])
                print(type(a[0]), b, c)
        except tf.errors.OutOfRangeError:
            print('stop')
            # This will be raised when you reach the end of an epoch (i.e. the
            # iterator has no more elements).
            pass                 

        # Perform any end-of-epoch computation here.
        print('Done training, epoch reached')

The script crashes without returning any results, stopping when it reaches a, b, c = sess.run([a1, a2, a3]). But when I comment out

query = tf.map_fn(str_to_int, query, dtype=tf.int32)
title = tf.map_fn(str_to_int, title, dtype=tf.int32)

it works and returns results.

In 2.txt, the data format looks like this:

1,2,3-4,5-0
1-2,3,4-1
4,5,6,7,8-9-0

In addition, why are the returned results bytes-like objects rather than str?

Recommended answer

I had a look, and it appears that if you replace:

query = tf.map_fn(str_to_int, query, dtype=tf.int32)
title = tf.map_fn(str_to_int, title, dtype=tf.int32)
label = line[2]

with

query = tf.string_to_number(query, out_type=tf.int32)
title = tf.string_to_number(title, out_type=tf.int32)
label = tf.string_to_number(line[2], out_type=tf.int32)

it works fine.

It appears that nesting two TensorFlow functional ops (tf.map_fn inside Dataset.map) just doesn't work. Luckily, the tf.map_fn was overcomplicating things anyway: tf.string_to_number already converts a whole string tensor element-wise, so no per-element mapping is needed.
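To check what the pipeline should produce for each record, it can help to have a reference parser outside TensorFlow. The following is a minimal plain-Python sketch, assuming every line has exactly three '-'-separated fields as in the 2.txt sample; parse_line is a hypothetical helper name, not part of TensorFlow or of the original code.

```python
def parse_line(line):
    """Parse one 'query-title-label' record into (list[int], list[int], int)."""
    # Fields are separated by '-', mirroring field_delim='-' in tf.decode_csv.
    # Assumption: exactly three fields per line, as in the 2.txt sample.
    query_field, title_field, label_field = line.strip().split("-")
    # Inside the query and title fields, ids are separated by ','.
    query = [int(tok) for tok in query_field.split(",")]
    title = [int(tok) for tok in title_field.split(",")]
    return query, title, int(label_field)


# The three sample lines from 2.txt.
sample_lines = ["1,2,3-4,5-0", "1-2,3,4-1", "4,5,6,7,8-9-0"]
parsed = [parse_line(line) for line in sample_lines]
for record in parsed:
    print(record)
```

Comparing these tuples with what sess.run returns is a quick way to confirm that the tf.string_to_number version parses the file as intended.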

Regarding your second question, I got this as output:

[(array([4, 5, 6, 7, 8], dtype=int32), array([9], dtype=int32), 0)]
<type 'numpy.ndarray'>
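On the bytes question: tf.string tensors hold raw byte strings, so under Python 3 string values fetched with sess.run come back as bytes rather than str. The output above contains no strings because everything was converted to int32, and its <type ...> repr indicates it was produced under Python 2. If a field is left as a string, it can be decoded after fetching. The sketch below uses plain bytes literals as a stand-in for fetched values and assumes UTF-8 content; to_text is a hypothetical helper name.

```python
def to_text(value, encoding="utf-8"):
    """Return str for bytes input; pass any other value through unchanged."""
    if isinstance(value, bytes):
        return value.decode(encoding)
    return value


# Stand-in for string labels fetched via sess.run under Python 3 (assumed values).
fetched_labels = [b"0", b"1", b"0"]
labels = [to_text(v) for v in fetched_labels]
print(labels)
```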
