How do you make TensorFlow + Keras fast with a TFRecord dataset?


Problem description

Is there an example of how to use a TensorFlow TFRecord with a Keras model and tf.session.run() while keeping the dataset in tensors with queue runners?

Below is a snippet that works, but it needs the following improvements:

  • Use the Model API
  • Specify an Input()
  • Load a dataset from a TFRecord
  • Run through the dataset in parallel (such as with a queue runner)

Here is the snippet; there are several TODO lines indicating what is needed:

from keras.models import Model
import tensorflow as tf
from keras import backend as K
from keras.layers import Dense, Input
from keras.objectives import categorical_crossentropy
from tensorflow.examples.tutorials.mnist import input_data

sess = tf.Session()
K.set_session(sess)

# Can this be done more efficiently than placeholders w/ TFRecords?
img = tf.placeholder(tf.float32, shape=(None, 784))
labels = tf.placeholder(tf.float32, shape=(None, 10))

# TODO: Use Input() 
x = Dense(128, activation='relu')(img)
x = Dense(128, activation='relu')(x)
preds = Dense(10, activation='softmax')(x)
# TODO: Construct model = Model(input=inputs, output=preds)

loss = tf.reduce_mean(categorical_crossentropy(labels, preds))

# TODO: handle TFRecord data, is it the same?
mnist_data = input_data.read_data_sets('MNIST_data', one_hot=True)

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

sess.run(tf.global_variables_initializer())

# TODO remove default, add queuerunner
with sess.as_default():
    for i in range(1000):
        batch = mnist_data.train.next_batch(50)
        train_step.run(feed_dict={img: batch[0],
                                  labels: batch[1]})
    print(loss.eval(feed_dict={img:    mnist_data.test.images, 
                               labels: mnist_data.test.labels}))

