How do I create padded batches in TensorFlow for tf.train.SequenceExample data using the Dataset API?

Question

To train an LSTM model in TensorFlow, I have structured my data in the tf.train.SequenceExample format and stored it in a TFRecord file. I would now like to use the new Dataset API to generate padded batches for training. The documentation has an example that uses padded_batch, but I can't figure out what the value of padded_shapes should be for my data.

To read the TFRecord file into batches, I wrote the following Python code:

import math
import tensorflow as tf
import numpy as np
import struct
import sys
import array

if(len(sys.argv) != 2):
  print("Usage: createbatches.py [TFRecord file]")
  sys.exit(0)


vectorSize = 40
inFile = sys.argv[1]

def parse_function_dataset(example_proto):
  sequence_features = {
      'inputs': tf.FixedLenSequenceFeature(shape=[vectorSize],
                                           dtype=tf.float32),
      'labels': tf.FixedLenSequenceFeature(shape=[],
                                           dtype=tf.int64)}

  _, sequence = tf.parse_single_sequence_example(example_proto, sequence_features=sequence_features)

  length = tf.shape(sequence['inputs'])[0]
  return sequence['inputs'], sequence['labels']

sess = tf.InteractiveSession()

filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.contrib.data.TFRecordDataset(filenames)
dataset = dataset.map(parse_function_dataset)
# dataset = dataset.batch(1)
dataset = dataset.padded_batch(4, padded_shapes=[None])
iterator = dataset.make_initializable_iterator()

batch = iterator.get_next()

# Initialize `iterator` with training data.
training_filenames = [inFile]
sess.run(iterator.initializer, feed_dict={filenames: training_filenames})

print(sess.run(batch))

The code works if I use dataset = dataset.batch(1) (no padding is needed in that case), but when I use the padded_batch variant, I get the following error:

TypeError: If shallow structure is a sequence, input must also be a sequence. Input has type: .

Can you help me figure out what I should pass for the padded_shapes parameter?

(I know there is a lot of example code that uses threading and queues for this, but I'd rather use the new Dataset API for this project.)

Recommended answer

You need to pass a tuple of shapes, one for each component that the map function returns. In your case you should pass

dataset = dataset.padded_batch(4, padded_shapes=([vectorSize],[None]))

or try

dataset = dataset.padded_batch(4, padded_shapes=([None],[None]))

Check this code for more details. I had to debug this method to figure out why it wasn't working for me.
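To make the "tuple of shapes" idea concrete, here is a minimal plain-Python sketch of what per-component padding does. It is not TensorFlow's implementation: `pad_batch`, `shape_of`, `full`, `pad` and the error messages are hypothetical names made up for illustration, and nested lists stand in for tensors. Each dataset element is an `(inputs, labels)` pair, so the shapes argument mirrors that pair, and `None` means "pad to the longest example in the batch". As far as I understand, TensorFlow also expects each padded shape to have the same rank as its component, so for `inputs` (shape `[sequence_length, vectorSize]`) a two-dimensional shape such as `[None, vectorSize]` may be needed rather than `[vectorSize]`; the sketch assumes that rule.

```python
# Plain-Python emulation of per-component padding. NOT TensorFlow code:
# all function names and error messages below are illustrative assumptions.

def shape_of(x):
    """Shape of a rectangular nested list, e.g. [[1, 2], [3, 4]] -> [2, 2]."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0] if x else None
    return shape


def full(shape, fill):
    """Nested list of the given shape filled with `fill`."""
    if not shape:
        return fill
    return [full(shape[1:], fill) for _ in range(shape[0])]


def pad(x, target, fill=0):
    """Pad nested list `x` with `fill` up to the shape `target`."""
    if not target:
        return x
    rows = [pad(item, target[1:], fill) for item in x]
    rows += [full(target[1:], fill) for _ in range(target[0] - len(x))]
    return rows


def pad_batch(elements, padded_shapes):
    """Pad a list of (component, ...) tuples into one batch per component."""
    n_components = len(elements[0])
    # One padded shape per component, in the same tuple structure as an element.
    if not isinstance(padded_shapes, tuple) or len(padded_shapes) != n_components:
        raise TypeError("padded_shapes must be a tuple with one shape per component")
    batch = []
    for i, spec in enumerate(padded_shapes):
        column = [element[i] for element in elements]
        shapes = [shape_of(c) for c in column]
        # Assumed rule: the padded shape has the same rank as the component.
        if any(len(s) != len(spec) for s in shapes):
            raise ValueError("component %d: padded shape rank does not match" % i)
        # None means "pad to the longest example in this batch".
        target = [max(s[d] for s in shapes) if dim is None else dim
                  for d, dim in enumerate(spec)]
        batch.append([pad(c, target) for c in column])
    return tuple(batch)


vector_size = 40  # plays the role of vectorSize in the question


def make_example(length):
    inputs = [[1.0] * vector_size for _ in range(length)]  # shape [length, 40]
    labels = [1] * length                                   # shape [length]
    return inputs, labels


examples = [make_example(n) for n in (2, 5, 3)]

inputs, labels = pad_batch(examples, ([None, vector_size], [None]))
print(shape_of(inputs), shape_of(labels))  # [3, 5, 40] [3, 5]

# A bare list, or a shape of the wrong rank, is rejected in this sketch.
for bad_shapes in ([None], ([vector_size], [None])):
    try:
        pad_batch(examples, bad_shapes)
    except (TypeError, ValueError) as err:
        print(type(err).__name__, "-", err)
```

Under the same assumptions, the corresponding call in the question's code would be dataset.padded_batch(4, padded_shapes=([None, vectorSize], [None])), which pads both components along the time dimension only.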
