Keras autoencoder with TensorFlow Dataset API and logging to TensorBoard


Question

I have a simple autoencoder in Keras. I want to log to TensorBoard (so I need to pass validation data) and load the data from a TFRecord file using the TensorFlow Dataset API with prefetching. I have read some articles about this, but they either omit the validation pipeline or omit the fact that passing the data directly without a feed dict is significantly slower.

The source code is:

import tensorflow as tf
from keras.losses import mean_squared_error
from keras.models import Sequential, Model
from keras.layers import (Dense, Input, Flatten, Reshape, Convolution2D,
                          Convolution2DTranspose, Conv2D, Conv2DTranspose)
from keras.optimizers import Adam
from keras import backend as K
from keras.callbacks import TensorBoard

def create_dataset(tf_record, batch_size):
    data = tf.data.TFRecordDataset(tf_record)
    data = data.map(TFReader._parse_example_encoded, num_parallel_calls=8)
    data = data.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=100))
    data = data.batch(batch_size, drop_remainder=True)
    data = data.prefetch(4)
    return data


def main(_):
    batch_size = 8  # todo: check and try bigger
    data = create_dataset('../../datasets/anime/no-game-no-life-ep-2.tfrecord', batch_size)
    iterator = data.make_one_shot_iterator()

    K.set_image_data_format('channels_last')  # set format

    input_tensor = Input(tensor=iterator.get_next())
    # `padding` is the Keras 2 name for the legacy `border_mode` argument
    out = Conv2D(8, (3, 3), activation='elu', padding='valid', batch_input_shape=(batch_size, 432, 768, 3))(input_tensor)
    out = Conv2D(16, (3, 3), activation='elu', padding='valid')(out)
    out = Conv2D(32, (3, 3), activation='elu', padding='valid', name='bottleneck')(out)
    out = Conv2DTranspose(32, (3, 3), activation='elu', padding='valid')(out)
    out = Conv2DTranspose(16, (3, 3), activation='elu', padding='valid')(out)
    out = Conv2DTranspose(8, (3, 3), activation='elu', padding='valid')(out)
    out = Conv2D(3, (3, 3), activation='elu', padding='same')(out)
    m = Model(inputs=input_tensor, outputs=out)
    m.compile(loss=mean_squared_error, optimizer=Adam(), target_tensors=iterator.get_next())
    print(m.summary())
    tensorboard = TensorBoard(
        log_dir='logs/anime', histogram_freq=5, embeddings_freq=5, embeddings_layer_names=['bottleneck'],
        write_images=True, embeddings_data=iterator.get_next(), embeddings_metadata='embeddings.tsv')
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    K.set_session(tf.Session(config=config))  # set_session was not imported; use it via the backend module

    history = m.fit(steps_per_epoch=100, epochs=50, verbose=1,
                validation_data=(iterator.get_next(), iterator.get_next()),
                validation_steps=4,
                callbacks=[tensorboard]
                )


if __name__ == '__main__':
    tf.app.run()

The training itself starts and the first epoch trains, but then it fails during validation with:

  File "C:\Users\Azathoth\AppData\Local\JetBrains\Toolbox\apps\PyCharm-P\ch-0\183.5429.31\helpers\pydev\pydevd.py", line 1741, in <module>
    main()
  File "C:\Users\Azathoth\AppData\Local\JetBrains\Toolbox\apps\PyCharm-P\ch-0\183.5429.31\helpers\pydev\pydevd.py", line 1735, in main
    globals = debugger.run(setup['file'], None, None, is_module)
  File "C:\Users\Azathoth\AppData\Local\JetBrains\Toolbox\apps\PyCharm-P\ch-0\183.5429.31\helpers\pydev\pydevd.py", line 1135, in run
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "C:\Users\Azathoth\AppData\Local\JetBrains\Toolbox\apps\PyCharm-P\ch-0\183.5429.31\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "E:/Projects/anime-style-transfer/code/neural_style_transfer/anime_dimension_reduction_keras.py", line 95, in <module>
    tf.app.run()
  File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\platform\app.py", line 125, in run
    _sys.exit(main(argv))
  File "E:/Projects/anime-style-transfer/code/neural_style_transfer/anime_dimension_reduction_keras.py", line 78, in main
    callbacks=[tensorboard]
  File "C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\training.py", line 1039, in fit
    validation_steps=validation_steps)
  File "C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\training_arrays.py", line 217, in fit_loop
    callbacks.on_epoch_end(epoch, epoch_logs)
  File "C:\ProgramData\Anaconda3\lib\site-packages\keras\callbacks.py", line 79, in on_epoch_end
    callback.on_epoch_end(epoch, logs)
  File "C:\ProgramData\Anaconda3\lib\site-packages\keras\callbacks.py", line 912, in on_epoch_end
    raise ValueError("If printing histograms, validation_data must be "
ValueError: If printing histograms, validation_data must be provided, and cannot be a generator.

I assume the problem lies in how the validation data is passed, because it directly uses the input tensor from the training TFRecord.

I don't need separate training and validation data, though, so if there is a way to tell Keras to validate on the same inputs, that would be fine as long as I get my TensorBoard logs.

Recommended answer

A few options:

  1. Have you looked at https://github.com/keras-team/keras/issues/3358 (solution by juiceboxjoe)?
    Write a TensorboardWrapper that loads the validation data from the generator, and pass it as the callback.
  2. If you don't care about validation, load a sample or two from the training data and pass them as arrays to validation_data.
  3. Set histogram_freq=0 if histograms are not needed.
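
Options 1 and 2 both come down to turning a few batches into plain arrays before the TensorBoard callback sees them. The following is only a minimal sketch of that collection step, not the code from the linked issue. The names `collect_validation_arrays` and `fake_batches` are hypothetical, and `fake_batches` is a stand-in for the real input pipeline so that the example runs without TensorFlow.

```python
import itertools

import numpy as np


def collect_validation_arrays(batches, num_batches):
    """Stack the first `num_batches` batches from an iterable into one array.

    For an autoencoder the target equals the input, so the same array is
    returned twice as an (inputs, targets) pair.
    """
    taken = [np.asarray(b) for b in itertools.islice(batches, num_batches)]
    if not taken:
        raise ValueError("no batches were available")
    x_val = np.concatenate(taken, axis=0)  # join along the batch axis
    return x_val, x_val


def fake_batches(batch_size=8, shape=(4, 6, 3)):
    """Hypothetical stand-in for the real pipeline: yields random image batches forever."""
    rng = np.random.RandomState(0)
    while True:
        yield rng.rand(batch_size, *shape).astype("float32")


if __name__ == "__main__":
    x_val, y_val = collect_validation_arrays(fake_batches(), num_batches=2)
    print(x_val.shape)  # two batches of 8 stacked along axis 0
```

In the question's TF 1.x setup, the batches could presumably come from evaluating a single `iterator.get_next()` tensor repeatedly in a session, with the result passed as `validation_data=(x_val, y_val)`. Whether a model built on `Input(tensor=...)` accepts array validation data is not verified here, so treat that part as an assumption.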
