How do I convert my basic feed-based TensorFlow code to use 'Dataset'?


Question


I understand that there are advantages (especially as I expand the scope of the models I build and the size of the datasets they work on) to using TensorFlow's new Dataset as the idiom for my data feeding pipeline. However, I'm having trouble mapping my existing feed_dict-based code to this new model.


One problem I face is that I can't sort out how batching and epochs interact, or how these interleave with the logging and validation that I often do.


For example, how does something like the following map to using Dataset?

# Load and process data into tensors of dimension (N, C_i) for input and (N, C_o) for output
# where N is the number of examples and C_ is the number of channels, and the values are activations
train_x, train_y, valid_x, valid_y = load_data(file, [segments], ...)
train_size = len(train_x)

train_stats_feed = {input_activation: train_x, correct_output: train_y, is_train: False}
valid_stats_feed = {input_activation: valid_x, correct_output: valid_y, is_train: False}

with tf.Session(config=tf.ConfigProto(...)) as sess:
    sess.run(tf.initialize_all_variables())

    # Some analysis; not always done but the code needs to support it
    train_writer.add_summary(sess.run(merged, feed_dict=train_stats_feed), 0)
    test_writer.add_summary(sess.run(merged, feed_dict=valid_stats_feed), 0)

    test_writer.add_summary(sess.run(gs_summary), 0)

    print(log_fmt.format(0, float(sess.run(accuracy, feed_dict=valid_stats_feed)),
                         float(sess.run(loss, feed_dict=valid_stats_feed))))

    for ep in range(epochs):
        # Slice the training data into random batches
        batch_indices = np.array_split(np.random.permutation(train_size), int(train_size/mb_size))

        for mini_batch_indices in batch_indices:
            sess.run(train_step, feed_dict={input_activation: train_x[mini_batch_indices],
                                            correct_output: train_y[mini_batch_indices], is_train: True})

            gs = int(sess.run(global_step))
            if gs % log_steps == 0:
                test_writer.add_summary(sess.run(merged, feed_dict=valid_stats_feed), gs)
                train_writer.add_summary(sess.run(merged, feed_dict=train_stats_feed), gs)

                acc = float(sess.run(accuracy, feed_dict=valid_stats_feed))
                sess.run(validation_accuracy.assign(acc))

                print(log_fmt.format(gs, acc, float(sess.run(loss, feed_dict=valid_stats_feed))))

        print(ep_fmt.format(ep + 2))
        test_writer.add_summary(sess.run(gs_summary), ep + 1)


Some of the less obvious definitions for the above, if needed:

# Preliminaries

# Some basic preliminaries, the details of which are not important to the question
# Mostly pretty standard; obvious things omitted from MWE for brevity
global_step = tf.Variable(0, trainable=False, name='global_step')
validation_accuracy = tf.Variable(0.0, trainable=False, name='validation_accuracy', dtype=tf.float32)

is_train = tf.placeholder(tf.bool, [], name='is_train')
input_activation = tf.placeholder(tf.float32, shape=[None, in_nodes], name='inputs')
correct_output = tf.placeholder(tf.float32, shape=[None, out_nodes], name='correct_outputs')

network_output = tf.identity(out_activations)
correct_predictions = correct_fn(correct_output, network_output)
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))
error = cost_fn(correct_output, network_output)
loss = error + FLAGS.regularization_weight * sum(tf.nn.l2_loss(w) for w in layer_weights)

train_step = tf.train.MomentumOptimizer(learning_rate, momentum=momentum).minimize(loss, global_step=global_step)

# Logging
train_writer = tf.summary.FileWriter(trainlogfile, tf.get_default_graph())
test_writer = tf.summary.FileWriter(testlogfile, tf.get_default_graph())
gs_summary = tf.summary.scalar('global_step_at_epoch', global_step)
merged = tf.summary.merge_all()

Answer


Here are a few lines of training code to get started. The same logic applies for validation.

# Define placeholder for inputs data and labels
inputs_placeholder = tf.placeholder(train_x.dtype, train_x.shape)
labels_placeholder = tf.placeholder(train_y.dtype, train_y.shape)
# Define a Dataset object using the above placeholders
dataset = tf.contrib.data.Dataset.from_tensor_slices((inputs_placeholder, labels_placeholder))
# Define batch_size
batch_size = 128
dataset = dataset.batch(batch_size)
# Define iterator
iterator = dataset.make_initializable_iterator()
# Get one batch
next_example, next_label = iterator.get_next()
# calculate loss from the model function you are using
loss = some_model(next_example, next_label)
# Set number of Epochs here
num_epochs = 100
for _ in range(num_epochs):
    sess.run(iterator.initializer, feed_dict={inputs_placeholder: train_x, labels_placeholder: train_y})
    while True:
        try:
            _loss = sess.run(loss)
        except tf.errors.OutOfRangeError:
            break
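To make "the same logic applies for validation" concrete, and to show where the question's periodic logging fits, here is one possible sketch (an illustration, not the only way) using a reinitializable iterator shared between a training dataset and a validation dataset. It assumes TF 1.4 or later, where the API lives under tf.data rather than tf.contrib.data, and that the question's loss, accuracy, merged, train_step and global_step are rebuilt on top of next_example / next_label instead of the old input_activation / correct_output placeholders; is_train can remain an ordinary placeholder fed through feed_dict.

# Sketch only: assumes loss, accuracy, merged, train_step, global_step are now
# built on next_example / next_label rather than the old feed_dict placeholders.
train_dataset = (tf.data.Dataset.from_tensor_slices((train_x, train_y))
                 .shuffle(buffer_size=len(train_x))   # replaces the manual permutation
                 .batch(batch_size))
# One big batch so a single get_next() yields the whole validation set.
# (For very large arrays, keep the placeholder + initializable-iterator
# pattern shown above instead of embedding the arrays in the graph.)
valid_dataset = tf.data.Dataset.from_tensor_slices((valid_x, valid_y)).batch(len(valid_x))

# A reinitializable iterator: same structure, switchable between the two datasets
iterator = tf.data.Iterator.from_structure(train_dataset.output_types,
                                           train_dataset.output_shapes)
next_example, next_label = iterator.get_next()
train_init_op = iterator.make_initializer(train_dataset)
valid_init_op = iterator.make_initializer(valid_dataset)

for ep in range(num_epochs):
    sess.run(train_init_op)          # start a new, freshly shuffled training epoch
    while True:
        try:
            sess.run(train_step, feed_dict={is_train: True})    # one minibatch per step
        except tf.errors.OutOfRangeError:
            break                    # training data exhausted: end of this epoch

    # Validate and log once per epoch, on the full validation set
    sess.run(valid_init_op)
    gs = int(sess.run(global_step))
    summary, acc, val_loss = sess.run([merged, accuracy, loss], feed_dict={is_train: False})
    test_writer.add_summary(summary, gs)
    print(log_fmt.format(gs, float(acc), float(val_loss)))

Re-initializing the iterator marks an epoch boundary, so shuffle(...).batch(...) replaces the manual permutation and np.array_split, and switching between train_init_op and valid_init_op replaces swapping feed_dicts. If you need validation every log_steps steps rather than once per epoch, a feedable iterator (tf.data.Iterator.from_string_handle) lets you switch to the validation set without resetting the training iterator's position.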

