Batch sequential data with tf.data

Question

Consider an ordered toy dataset with two features:

  • value (e.g. 1, 2, 3, 4, 5, 111, 222, 333, 444, 555)
  • sequence_id (e.g. 0, 0, 0, 0, 0, 1, 1, 1, 1, 1)

The data is simply two flat sequences concatenated: 1, 2, 3, 4, 5 (sequence 0) and 111, 222, 333, 444, 555 (sequence 1).

I would like to generate sequences of size t (say, 3) made of consecutive elements from the same sequence (same sequence_id); no generated sequence should contain elements with different sequence_id values.

For instance, without any shuffling, I would like to get the following batches:

  • Batch 1: 1, 2, 3
  • Batch 2: 2, 3, 4
  • Batch 3: 3, 4, 5
  • Batch 4: 111, 222, 333
  • Batch 5: 222, 333, 444
  • Batch 6: 333, 444, 555
  • Batch 7: 1, 2, 3

I know how to generate sequence data with tf.data.Dataset.window or tf.data.Dataset.batch, but I do not know how to prevent a sequence from mixing different sequence_id values (e.g., the sequence 4, 5, 111 is invalid because it mixes elements from sequence 0 and sequence 1).
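To pin down exactly which windows are wanted, here is a framework-free reference sketch in plain Python (not a tf.data solution). It assumes the rows are already ordered so that each sequence_id forms one contiguous run; the helper name `windows_within_sequences` is hypothetical. It splits the data wherever sequence_id changes and slides a size-t, shift-1 window inside each run only.

```python
from itertools import groupby


def windows_within_sequences(values, seq_ids, t):
    """Yield every run of t consecutive values that share one sequence_id (hypothetical helper)."""
    pairs = zip(values, seq_ids)
    # groupby starts a new group wherever sequence_id changes between adjacent rows
    for _, group in groupby(pairs, key=lambda p: p[1]):
        seq = [v for v, _ in group]
        # slide a window of size t with shift 1, staying inside this sequence
        for start in range(len(seq) - t + 1):
            yield seq[start:start + t]


values = [1, 2, 3, 4, 5, 111, 222, 333, 444, 555]
seq_ids = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
for w in windows_within_sequences(values, seq_ids, 3):
    print(w)
```

This yields the six windows listed above (batches 1 to 6) and never produces 4, 5, 111 or 5, 111, 222.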

Here is my failed attempt:

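import tensorflow as tf  # TensorFlow 1.x API (tf.Session, make_initializable_iterator)

data = (tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5, 111, 222, 333, 444, 555],
                                            [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]))
        .window(3, 1, drop_remainder=True)    # sliding windows of size 3, shift 1
        .repeat(-1)                           # cycle through the data forever
        .flat_map(lambda x, y: x.batch(3))    # turn each value window into a tensor; sequence_id is ignored
        .batch(10))                           # 10 windows per printed batch
data_it = data.make_initializable_iterator()
next_element = data_it.get_next()

with tf.Session() as sess:
    sess.run(data_it.initializer)
    print(sess.run(next_element))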

Output:

[[  1   2   3]   # good
 [  2   3   4]   # good
 [  3   4   5]   # good
 [  4   5 111]   # bad – mix of sequence 0 (4, 5) and sequence 1 (111)
 [  5 111 222]   # bad
 [111 222 333]   # good
 [222 333 444]   # good
 [333 444 555]   # good
 [  1   2   3]   # good
 [  2   3   4]]  # good

Recommended answer

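You can use filter() to check whether all sequence_id values in a window are the same. Because the filter() transformation does not currently support nested datasets as input (each element produced by window() is a tuple of datasets), you first need zip() to combine the value and sequence_id windows, then batch(3) to turn each window into a pair of tensors that filter() can inspect.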
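As a framework-free illustration of what that pipeline computes, here is a plain-Python sketch (an analogy, not tf.data code): the first list comprehension plays the role of window(3, 1, drop_remainder=True) plus the zip/batch step, and the second plays the role of filter() with the "exactly one distinct sequence_id" test. The variable names are illustrative.

```python
values = [1, 2, 3, 4, 5, 111, 222, 333, 444, 555]
seq_ids = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
t = 3

# Every sliding window of size t, shift 1, over both columns at once
all_windows = [(values[i:i + t], seq_ids[i:i + t]) for i in range(len(values) - t + 1)]

# Keep a window only if its ids contain exactly one distinct value
kept = [v for v, ids in all_windows if len(set(ids)) == 1]

# Windows discarded because they straddle a sequence boundary
dropped = [v for v, ids in all_windows if len(set(ids)) != 1]

print(kept)
print(dropped)
```

The design trade-off is that every window is built first and the invalid ones are thrown away afterwards; each boundary between two sequences is straddled by at most t - 1 windows, so the waste stays small when sequences are long relative to t.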

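import tensorflow as tf  # TensorFlow 1.x API (tf.Session, make_initializable_iterator)

data = (tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5, 111, 222, 333, 444, 555],
                                            [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]))
        .window(3, 1, drop_remainder=True)
        # zip the value window with its sequence_id window; batch(3) makes one (values, ids) tensor pair
        .flat_map(lambda x, y: tf.data.Dataset.zip((x, y)).batch(3))
        # keep only windows whose sequence_id tensor has exactly one distinct value
        .filter(lambda x, y: tf.equal(tf.size(tf.unique(y)[0]), 1))
        # drop the sequence_id and keep the values
        .map(lambda x, y: x)
        .repeat(-1)
        .batch(10))
data_it = data.make_initializable_iterator()
next_element = data_it.get_next()

with tf.Session() as sess:
    sess.run(data_it.initializer)
    print(sess.run(next_element))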

[[  1   2   3]
 [  2   3   4]
 [  3   4   5]
 [111 222 333]
 [222 333 444]
 [333 444 555]
 [  1   2   3]
 [  2   3   4]
 [  3   4   5]
 [111 222 333]]
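The code above uses the TensorFlow 1.x graph API; tf.Session and Dataset.make_initializable_iterator are not part of the top-level TensorFlow 2.x API (they survive only under tf.compat.v1). A sketch of the same window → zip/batch → filter → map pipeline under eager execution, assuming TensorFlow 2.x, where the dataset can be iterated directly:

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution

values = [1, 2, 3, 4, 5, 111, 222, 333, 444, 555]
seq_ids = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

data = (tf.data.Dataset.from_tensor_slices((values, seq_ids))
        .window(3, shift=1, drop_remainder=True)
        # each element is a (values, ids) pair of sub-datasets; zip and batch into two tensors
        .flat_map(lambda x, y: tf.data.Dataset.zip((x, y)).batch(3))
        # keep windows whose ids contain exactly one distinct value
        .filter(lambda x, y: tf.equal(tf.size(tf.unique(y)[0]), 1))
        .map(lambda x, y: x)
        .repeat()
        .batch(10))

# No Session or iterator initializer: take the first batch straight from the dataset
first_batch = next(iter(data)).numpy()
print(first_batch)
```

The printed batch should match the output shown above.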
