How to Properly Combine TensorFlow's Dataset API and Keras?

Question

Keras' fit_generator() model method expects a generator which produces tuples of the shape (input, targets), where both elements are NumPy arrays. The documentation seems to imply that if I simply wrap a Dataset iterator in a generator, and make sure to convert the Tensors to NumPy arrays, I should be good to go. This code, however, gives me an error:

import numpy as np
import os
import keras.backend as K
from keras.layers import Dense, Input
from keras.models import Model
import tensorflow as tf
from tensorflow.contrib.data import Dataset

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

with tf.Session() as sess:
    def create_data_generator():
        dat1 = np.arange(4).reshape(-1, 1)
        ds1 = Dataset.from_tensor_slices(dat1).repeat()

        dat2 = np.arange(5, 9).reshape(-1, 1)
        ds2 = Dataset.from_tensor_slices(dat2).repeat()

        ds = Dataset.zip((ds1, ds2)).batch(4)
        iterator = ds.make_one_shot_iterator()
        while True:
            next_val = iterator.get_next()
            yield sess.run(next_val)

datagen = create_data_generator()

input_vals = Input(shape=(1,))
output = Dense(1, activation='relu')(input_vals)
model = Model(inputs=input_vals, outputs=output)
model.compile('rmsprop', 'mean_squared_error')
model.fit_generator(datagen, steps_per_epoch=1, epochs=5,
                    verbose=2, max_queue_size=2)

This is the error I get:

Using TensorFlow backend.
Epoch 1/5
Exception in thread Thread-1:
Traceback (most recent call last):
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 270, in __init__
    fetch, allow_tensor=True, allow_operation=True))
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2708, in as_graph_element
    return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2787, in _as_graph_element_locked
    raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("IteratorGetNext:0", shape=(?, 1), dtype=int64) is not an element of this graph.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/jsaporta/anaconda3/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/home/jsaporta/anaconda3/lib/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/utils/data_utils.py", line 568, in data_generator_task
    generator_output = next(self._generator)
  File "./datagen_test.py", line 25, in create_data_generator
    yield sess.run(next_val)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 895, in run
    run_metadata_ptr)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1109, in _run
    self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 413, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 233, in for_fetch
    return _ListFetchMapper(fetch)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 340, in __init__
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 340, in <listcomp>
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 241, in for_fetch
    return _ElementFetchMapper(fetches, contraction_fn)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 277, in __init__
    'Tensor. (%s)' % (fetch, str(e)))
ValueError: Fetch argument <tf.Tensor 'IteratorGetNext:0' shape=(?, 1) dtype=int64> cannot be interpreted as a Tensor. (Tensor Tensor("IteratorGetNext:0", shape=(?, 1), dtype=int64) is not an element of this graph.)

Traceback (most recent call last):
  File "./datagen_test.py", line 34, in <module>
    verbose=2, max_queue_size=2)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 87, in wrapper
    return func(*args, **kwargs)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 2011, in fit_generator
    generator_output = next(output_generator)
StopIteration
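
For comparison, the (input, targets) contract that fit_generator expects can be met without creating any TensorFlow ops inside the generator. The sketch below uses a hypothetical numpy_batch_generator function and assumes only NumPy. It yields the same batches the question's Dataset pipeline is meant to produce (inputs 0 to 3, targets 5 to 8, batch size 4, repeating forever). Such a generator could be passed to fit_generator in place of datagen, although it still moves the data through NumPy, which is what the last question below hopes to avoid.

```python
import itertools

import numpy as np


def numpy_batch_generator(x, y, batch_size):
    """Hypothetical helper: yield (inputs, targets) tuples of NumPy arrays forever,
    cycling through the data the way .repeat() followed by .batch() would."""
    n = len(x)
    for start in itertools.count(0, batch_size):
        idx = np.arange(start, start + batch_size) % n  # wrap around like .repeat()
        yield x[idx], y[idx]


dat1 = np.arange(4).reshape(-1, 1)
dat2 = np.arange(5, 9).reshape(-1, 1)
gen = numpy_batch_generator(dat1, dat2, batch_size=4)
xb, yb = next(gen)
print(xb.ravel(), yb.ravel())  # [0 1 2 3] [5 6 7 8]
```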

Strangely enough, adding a line containing next(datagen) directly after where I initialize datagen causes the code to run just fine, with no errors.
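
Part of what that extra line changes is plain Python generator semantics. Calling create_data_generator() only creates a generator object; its body, including the Dataset and iterator construction, runs on the first next() call. The traceback shows that fit_generator makes that first call from a background thread ("Exception in thread Thread-1", inside data_generator_task), whereas an explicit next(datagen) right after creation runs the body in the main thread. The sketch below uses a hypothetical make_generator stand-in and no TensorFlow at all; it only shows which thread executes the setup code in each case. That this thread difference is what places the tensors outside the session's graph is an assumption and is not established by the answer below.

```python
import threading


def make_generator(log):
    # Hypothetical stand-in for the Dataset/iterator construction in the question.
    # This line runs on the first next() call, not when the generator is created.
    log.append(threading.current_thread().name)
    while True:
        yield 0


def consume_in_worker(gen):
    # Mimics fit_generator, which pulls batches from a background thread.
    worker = threading.Thread(target=lambda: next(gen), name="worker")
    worker.start()
    worker.join()


# Case 1: the generator is created but never advanced in the main thread.
lazy_log = []
lazy_gen = make_generator(lazy_log)
print(lazy_log)    # [] (the body has not run yet)
consume_in_worker(lazy_gen)
print(lazy_log)    # ['worker']

# Case 2: "prime" the generator with next() in the main thread first.
primed_log = []
primed_gen = make_generator(primed_log)
next(primed_gen)
consume_in_worker(primed_gen)
print(primed_log)  # ['MainThread'] when run as a script
```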

Why does my original code not work? Why does it begin to work when I add that line to my code? Is there a more efficient way to use TensorFlow's Dataset API with Keras that doesn't involve converting Tensors to NumPy arrays and back again?

Recommended answer

There is indeed a more efficient way to use Dataset without having to convert the tensors into NumPy arrays. However, it is not (yet?) in the official documentation. According to the release notes, it is a feature introduced in Keras 2.0.7, so you may have to install keras>=2.0.7 to use it.
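
The version requirement can be checked at runtime. Comparing version strings directly is a trap, since "2.0.10" sorts before "2.0.7" as text, so the minimal sketch below compares integer tuples instead. The helper names are hypothetical, and it assumes plain dotted numeric versions (a suffix such as "rc1" or "-tf" would need extra parsing); in practice you would pass keras.__version__.

```python
def version_tuple(version):
    # "2.0.10" -> (2, 0, 10); assumes plain dotted numeric versions only.
    return tuple(int(part) for part in version.split(".")[:3])


def supports_tensor_inputs(keras_version):
    # Hypothetical helper: Input(tensor=...) plus target_tensors needs Keras >= 2.0.7.
    return version_tuple(keras_version) >= (2, 0, 7)


print(supports_tensor_inputs("2.0.6"))   # False
print(supports_tensor_inputs("2.0.10"))  # True
print("2.0.10" < "2.0.7")                # True: plain string comparison gets this wrong
```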

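import numpy as np
from keras.layers import Dense, Input
from keras.models import Model
from tensorflow.contrib.data import Dataset  # tf.data.Dataset in TensorFlow >= 1.4

# Inputs: cast to float32 yourself; Keras will not do it for tensor inputs.
x = np.arange(4).reshape(-1, 1).astype('float32')
ds_x = Dataset.from_tensor_slices(x).repeat().batch(4)  # batch size is fixed here
it_x = ds_x.make_one_shot_iterator()

# Targets: same treatment.
y = np.arange(5, 9).reshape(-1, 1).astype('float32')
ds_y = Dataset.from_tensor_slices(y).repeat().batch(4)
it_y = ds_y.make_one_shot_iterator()

# Feed the iterator's output tensor straight into the model's input layer.
input_vals = Input(tensor=it_x.get_next())
output = Dense(1, activation='relu')(input_vals)
model = Model(inputs=input_vals, outputs=output)
# Bind the target tensor at compile time instead of passing y to fit().
model.compile('rmsprop', 'mse', target_tensors=[it_y.get_next()])
# No x/y arguments: steps_per_epoch and epochs decide when training stops.
model.fit(steps_per_epoch=1, epochs=5, verbose=2)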

Compared with the usual way of calling fit(), there are a few differences:

  1. Supply the tensor argument to the Input layer. Keras will read values from this tensor and use them as the input when fitting the model.
  2. Supply the target_tensors argument to Model.compile().
  3. Remember to convert both x and y to float32. Under normal usage Keras does this conversion for you, but here you have to do it yourself.
  4. The batch size is specified when the Dataset is constructed. Use steps_per_epoch and epochs to control when model fitting stops.
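
Points 3 and 4 can be sketched in plain NumPy: the float32 cast that Keras would otherwise do for you, and the arithmetic linking batch size, steps_per_epoch and epochs when the Dataset repeats forever. The helper steps_for_one_pass is hypothetical and assumes you want each epoch to cover the data once, rounding up (with .repeat().batch(), the last batch of a pass then wraps into the next pass).

```python
import math

import numpy as np

# Point 3: cast yourself; np.arange gives a platform integer dtype (e.g. int64).
x = np.arange(4).reshape(-1, 1).astype('float32')
y = np.arange(5, 9).reshape(-1, 1).astype('float32')
print(x.dtype, y.dtype)  # float32 float32


def steps_for_one_pass(num_samples, batch_size):
    # Hypothetical helper: batches needed so one epoch covers every sample once.
    return math.ceil(num_samples / batch_size)


# Point 4: the batch size lives in the Dataset (.batch(4)); fit() only counts steps.
batch_size = 4
epochs = 5
steps_per_epoch = steps_for_one_pass(len(x), batch_size)
print(steps_per_epoch)                        # 1
print(steps_per_epoch * batch_size * epochs)  # 20 samples drawn over the whole run
```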

In short, use Input(tensor=...), model.compile(target_tensors=...) and model.fit(x=None, y=None, ...) if your data are to be read from tensors.
