tensorflow dataset API doesn't work stably when batch size is greater than 1

Problem Description

I put a group of fixed-length and variable-length features into one tf.train.SequenceExample.

context_features
    length,            scalar,                    tf.int64
    site_code_raw,     scalar,                    tf.string
    Date_Local_raw,    scalar,                    tf.string
    Time_Local_raw,    scalar,                    tf.string
Sequence_features
    Orig_RefPts,       [#batch, #RefPoints, 4]    tf.float32
    tgt_location,      [#batch, 3]                tf.float32
    tgt_val            [#batch, 1]                tf.float32

The value of #RefPoints is variable for different sequence examples. I store its value in length feature in the context_features. The rest features have fixed sizes.
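
For reference, here is a minimal sketch of how one such record could be serialized (the helper name, file name, and sample values below are made up for illustration; only the feature keys match the parsing code):

import numpy as np
import tensorflow as tf

def make_seq_example(length, site_code, date_local, time_local,
                     ref_pts, tgt_location, tgt_val):
    # ref_pts has shape [length, 4]; the number of rows varies per record
    ex = tf.train.SequenceExample()
    ex.context.feature['length'].int64_list.value.append(length)
    ex.context.feature['site_code'].bytes_list.value.append(site_code)
    ex.context.feature['Date_Local'].bytes_list.value.append(date_local)
    ex.context.feature['Time_Local'].bytes_list.value.append(time_local)

    fl = ex.feature_lists.feature_list
    for row in ref_pts:  # one FeatureList step per reference point
        fl['input_features'].feature.add().float_list.value.extend(row)
    fl['tgt_location_features'].feature.add().float_list.value.extend(tgt_location)
    fl['tgt_val_feature'].feature.add().float_list.value.append(tgt_val)
    return ex

# hypothetical usage with made-up values
ex = make_seq_example(5, b'0001', b'2017-01-01', b'00:00',
                      np.random.rand(5, 4), [1.0, 2.0, 3.0], 0.5)
writer = tf.python_io.TFRecordWriter('toy.tfrecords')
writer.write(ex.SerializeToString())
writer.close()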

Here is the code I am using to read & parse the data:

def read_batch_DatasetAPI(
    filenames, 
    batch_size = 20, 
    num_epochs = None, 
    buffer_size = 5000):

    dataset = tf.contrib.data.TFRecordDataset(filenames)
    dataset = dataset.map(_parse_SeqExample1)
    if (buffer_size is not None):
        dataset = dataset.shuffle(buffer_size=buffer_size)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    # next_element contains a tuple of following tensors
    # length,            scalar,                    tf.int64
    # site_code_raw,     scalar,                    tf.string
    # Date_Local_raw,    scalar,                    tf.string
    # Time_Local_raw,    scalar,                    tf.string
    # Orig_RefPts,       [#batch, #RefPoints, 4]    tf.float32
    # tgt_location,      [#batch, 3]                tf.float32
    # tgt_val            [#batch, 1]                tf.float32

    return iterator, next_element

def _parse_SeqExample1(in_SeqEx_proto):

    # Define how to parse the example
    context_features = {
        'length': tf.FixedLenFeature([], dtype=tf.int64),
        'site_code': tf.FixedLenFeature([], dtype=tf.string),
        'Date_Local': tf.FixedLenFeature([], dtype=tf.string),
        'Time_Local': tf.FixedLenFeature([], dtype=tf.string) #,
    }

    sequence_features = {
        "input_features": tf.VarLenFeature(dtype=tf.float32),
        'tgt_location_features': tf.FixedLenSequenceFeature([3], dtype=tf.float32),
        'tgt_val_feature': tf.FixedLenSequenceFeature([1], dtype=tf.float32)   
    }                                                        

    context, sequence = tf.parse_single_sequence_example(
      in_SeqEx_proto, 
      context_features=context_features,
      sequence_features=sequence_features)

    # distribute the fetched context and sequence features into tensors
    length = context['length']
    site_code_raw = context['site_code']
    Date_Local_raw = context['Date_Local']
    Time_Local_raw = context['Time_Local']

    # reshape the tensors according to the dimension definition above
    Orig_RefPts = sequence['input_features'].values
    Orig_RefPts = tf.reshape(Orig_RefPts, [-1, 4])
    tgt_location = sequence['tgt_location_features']
    tgt_location = tf.reshape(tgt_location, [-1])
    tgt_val = sequence['tgt_val_feature']
    tgt_val = tf.reshape(tgt_val, [-1])

    return length, site_code_raw, Date_Local_raw, Time_Local_raw, \
        Orig_RefPts, tgt_location, tgt_val

When I call read_batch_DatasetAPI with batch_size = 1 (see the code below), it can process all (around 200,000) Sequence Examples one-by-one without any problem. But if I change the batch_size to any number greater than 1, it simply stopped after fetching 320 to 700 Sequence Examples without any error message. I don't know how to solve this problem. Any help is appreciated!

# the iterator to get the next_element for one sample (in sequence)
iterator, next_element = read_batch_DatasetAPI(
    in_tf_FWN,  # the file name of the tfrecords containing ~200,000 Sequence Examples
    batch_size = 1,  # works when it is 1, doesn't work if > 1
    num_epochs = 1,
    buffer_size = None)

# tf session initialization
sess = tf.Session()
sess.run(tf.global_variables_initializer())

## reset the iterator to the beginning
sess.run(iterator.initializer)

try:
    step = 0

    while (True):

        # get the next batch data
        length, site_code_raw, Date_Local_raw, Time_Local_raw, \
        Orig_RefPts, tgt_location, tgt_vale = sess.run(next_element)

        step = step + 1

except tf.errors.OutOfRangeError:
    # Task Done (all SeqExs have been visited)
    print("closing ", in_tf_FWN)

except ValueError as err:
    print("Error: {}".format(err.args))

except Exception as err:
    print("Error: {}".format(err.args))

Recommended Answer

I saw some posts (Example 1 and Example 2) mentioning the new dataset function from_generator (https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/data/Dataset#from_generator). I'm not sure how to use it to solve my problem yet. Anyone knows how to do it, please post it as a new answer. Thank you!
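
For reference, the general call pattern from the linked documentation looks roughly like the untested sketch below. The generator and the placeholder shapes are made up; actually decoding my SequenceExamples in plain Python (e.g. with tf.python_io.tf_record_iterator) is the part I haven't worked out:

import numpy as np
import tensorflow as tf

def gen():
    # placeholder generator: in practice it would iterate the tfrecords file
    # and yield already-decoded numpy arrays for each SequenceExample
    for _ in range(3):
        n = np.random.randint(1, 6)
        yield (np.random.rand(n, 4).astype(np.float32),  # variable-length RefPoints
               np.random.rand(3).astype(np.float32),     # tgt_location
               np.random.rand(1).astype(np.float32))     # tgt_val

dataset = tf.contrib.data.Dataset.from_generator(
    gen,
    output_types=(tf.float32, tf.float32, tf.float32),
    output_shapes=(tf.TensorShape([None, 4]),
                   tf.TensorShape([3]),
                   tf.TensorShape([1])))
# shuffling / batching of the variable-length elements would still go here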

Here is my current diagnosis and solution to my question:

The variation of the sequence length (#RefPoints) caused the problem. Batching after dataset.map(_parse_SeqExample1) only works if the #RefPoints values happen to be the same within a batch. That's why it always worked when batch_size was 1, but failed at some point when it was greater than 1.
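
A minimal, self-contained illustration of this behaviour (synthetic data, not my actual pipeline):

import tensorflow as tf

# toy elements with shapes [1, 4], [2, 4], [3, 4], [4, 4]
ds = tf.contrib.data.Dataset.range(1, 5)
ds = ds.map(lambda n: tf.fill(tf.stack([tf.to_int32(n), 4]), 1.0))

nxt = ds.batch(2).make_one_shot_iterator().get_next()
with tf.Session() as sess:
    sess.run(nxt)  # fails at run time: the two elements in a batch have different shapes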

I found that Dataset has a padded_batch function which pads the variable-length elements to the maximum length in the batch. A few changes were made to temporarily solve my problem (I guess from_generator will be the real solution in my case):

  1. In the _parse_SeqExample1 function, the return statement was changed to

return tf.tuple([length, site_code_raw, Date_Local_raw, Time_Local_raw, \
    Orig_RefPts, tgt_location, tgt_val])

  2. In the read_batch_DatasetAPI function, the statement

dataset = dataset.batch(batch_size)

was changed to

dataset = dataset.padded_batch(
    batch_size,
    padded_shapes=(tf.TensorShape([]),         # length
                   tf.TensorShape([]),         # site_code_raw
                   tf.TensorShape([]),         # Date_Local_raw
                   tf.TensorShape([]),         # Time_Local_raw
                   tf.TensorShape([None, 4]),  # Orig_RefPts
                   tf.TensorShape([3]),        # tgt_location
                   tf.TensorShape([1])))       # tgt_val

  3. Finally, change the fetch statement from

length, site_code_raw, Date_Local_raw, Time_Local_raw, \
Orig_RefPts, tgt_location, tgt_vale = sess.run(next_element)

to

[length, site_code_raw, Date_Local_raw, Time_Local_raw, \
Orig_RefPts_val, tgt_location, tgt_vale] = sess.run(next_element)
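
To check that the padding behaves as expected, I reused the toy elements from the diagnosis above (again synthetic data, not my records):

import tensorflow as tf

# same toy elements as before: shapes [1, 4], [2, 4], [3, 4], [4, 4]
ds = tf.contrib.data.Dataset.range(1, 5)
ds = ds.map(lambda n: tf.fill(tf.stack([tf.to_int32(n), 4]), 1.0))
padded = ds.padded_batch(2, padded_shapes=tf.TensorShape([None, 4]))

nxt = padded.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    print(sess.run(nxt).shape)  # (2, 2, 4): zero-padded to the longest element in the batch
    print(sess.run(nxt).shape)  # (2, 4, 4)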

Note: I don't know why, but this only works on the current tf-nightly-gpu version, not on tensorflow-gpu v1.3.
