How to Properly Combine TensorFlow's Dataset API and Keras?
Question
Keras' fit_generator() model method expects a generator that produces tuples of the form (inputs, targets), where both elements are NumPy arrays. The documentation seems to imply that if I simply wrap a Dataset iterator in a generator, and make sure to convert the Tensors to NumPy arrays, I should be good to go. This code, however, gives me an error:
import numpy as np
import os
import keras.backend as K
from keras.layers import Dense, Input
from keras.models import Model
import tensorflow as tf
from tensorflow.contrib.data import Dataset

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

with tf.Session() as sess:
    def create_data_generator():
        dat1 = np.arange(4).reshape(-1, 1)
        ds1 = Dataset.from_tensor_slices(dat1).repeat()

        dat2 = np.arange(5, 9).reshape(-1, 1)
        ds2 = Dataset.from_tensor_slices(dat2).repeat()

        ds = Dataset.zip((ds1, ds2)).batch(4)
        iterator = ds.make_one_shot_iterator()
        while True:
            next_val = iterator.get_next()
            yield sess.run(next_val)

    datagen = create_data_generator()

    input_vals = Input(shape=(1,))
    output = Dense(1, activation='relu')(input_vals)
    model = Model(inputs=input_vals, outputs=output)
    model.compile('rmsprop', 'mean_squared_error')
    model.fit_generator(datagen, steps_per_epoch=1, epochs=5,
                        verbose=2, max_queue_size=2)
Here is the error I get:
Using TensorFlow backend.
Epoch 1/5
Exception in thread Thread-1:
Traceback (most recent call last):
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 270, in __init__
fetch, allow_tensor=True, allow_operation=True))
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2708, in as_graph_element
return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2787, in _as_graph_element_locked
raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("IteratorGetNext:0", shape=(?, 1), dtype=int64) is not an element of this graph.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/jsaporta/anaconda3/lib/python3.6/threading.py", line 916, in _bootstrap_inner
self.run()
File "/home/jsaporta/anaconda3/lib/python3.6/threading.py", line 864, in run
self._target(*self._args, **self._kwargs)
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/utils/data_utils.py", line 568, in data_generator_task
generator_output = next(self._generator)
File "./datagen_test.py", line 25, in create_data_generator
yield sess.run(next_val)
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 895, in run
run_metadata_ptr)
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1109, in _run
self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 413, in __init__
self._fetch_mapper = _FetchMapper.for_fetch(fetches)
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 233, in for_fetch
return _ListFetchMapper(fetch)
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 340, in __init__
self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 340, in <listcomp>
self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 241, in for_fetch
return _ElementFetchMapper(fetches, contraction_fn)
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 277, in __init__
'Tensor. (%s)' % (fetch, str(e)))
ValueError: Fetch argument <tf.Tensor 'IteratorGetNext:0' shape=(?, 1) dtype=int64> cannot be interpreted as a Tensor. (Tensor Tensor("IteratorGetNext:0", shape=(?, 1), dtype=int64) is not an element of this graph.)
Traceback (most recent call last):
File "./datagen_test.py", line 34, in <module>
verbose=2, max_queue_size=2)
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 87, in wrapper
return func(*args, **kwargs)
File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 2011, in fit_generator
generator_output = next(output_generator)
StopIteration
Strangely enough, adding a line containing next(datagen) directly after where I initialize datagen causes the code to run just fine, with no errors.
Why does my original code not work? Why does it begin to work when I add that line to my code? Is there a more efficient way to use TensorFlow's Dataset API with Keras that doesn't involve converting Tensors to NumPy arrays and back again?
Recommended answer
There is indeed a more efficient way to use Dataset without having to convert the tensors into NumPy arrays. However, it is not (yet?) in the official documentation. According to the release notes, it is a feature introduced in Keras 2.0.7, so you may have to install keras>=2.0.7 in order to use it.
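import numpy as np
from keras.layers import Dense, Input
from keras.models import Model
from tensorflow.contrib.data import Dataset

# Inputs and targets: cast to float32 by hand and batch inside the Dataset.
x = np.arange(4).reshape(-1, 1).astype('float32')
ds_x = Dataset.from_tensor_slices(x).repeat().batch(4)
it_x = ds_x.make_one_shot_iterator()
y = np.arange(5, 9).reshape(-1, 1).astype('float32')
ds_y = Dataset.from_tensor_slices(y).repeat().batch(4)
it_y = ds_y.make_one_shot_iterator()

# The model reads its input directly from the iterator's output tensor.
input_vals = Input(tensor=it_x.get_next())
output = Dense(1, activation='relu')(input_vals)
model = Model(inputs=input_vals, outputs=output)

# The targets also come from a tensor, so fit() is called without x or y.
model.compile('rmsprop', 'mse', target_tensors=[it_y.get_next()])
model.fit(steps_per_epoch=1, epochs=5, verbose=2)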
A few differences:
- Supply the tensor argument to the Input layer. Keras will read values from this tensor and use them as the input to fit the model.
- Supply the target_tensors argument to Model.compile().
- Remember to convert both x and y to float32. Under normal usage, Keras does this conversion for you, but here you have to do it yourself.
- The batch size is specified when the Dataset is constructed. Use steps_per_epoch and epochs to control when model fitting stops.
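The last two points can be checked with NumPy alone. This is a minimal sketch and not part of the original answer: it shows that np.arange() returns an integer array until it is cast explicitly, and how many samples the fit() call above would consume given the batch size of 4. The names raw, batch_size, and samples_seen are illustrative.

```python
import numpy as np

# np.arange() yields an integer dtype by default, so the answer casts it
# explicitly before handing the data to the Dataset.
raw = np.arange(4).reshape(-1, 1)
x = raw.astype('float32')
print(raw.dtype.kind, x.dtype)

# Values taken from the answer's code: .batch(4), steps_per_epoch=1, epochs=5.
batch_size = 4
steps_per_epoch = 1
epochs = 5

# Each training step pulls one batch from the repeating Dataset, so fitting
# stops after steps_per_epoch * epochs batches.
samples_seen = batch_size * steps_per_epoch * epochs
print(samples_seen)
```

Because the Dataset repeats indefinitely, these two arguments are the only thing that bounds training in this setup.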
In short, use Input(tensor=...), model.compile(target_tensors=...) and model.fit(x=None, y=None, ...) if your data are to be read from tensors.
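As for why the extra next(datagen) line changes the behavior, one likely factor is that a Python generator does not run any of its body until next() is first called, and the body then runs in whichever thread makes that call. The traceback in the question shows next(self._generator) being called from Keras' worker thread (Thread-1), so without the extra line the Dataset and iterator are built there rather than in the main thread. The standard-library sketch below demonstrates only the Python behavior. The names make_gen, first_next_in_worker, and 'worker' are hypothetical, and the link to TensorFlow's graph handling is an assumption rather than something confirmed in this answer.

```python
import threading


def make_gen(log):
    # Nothing in this body runs until the first next() call.
    log.append(threading.current_thread().name)  # records where setup ran
    while True:
        yield 0


def first_next_in_worker(gen):
    # Mimics a consumer that pulls from the generator in a worker thread.
    worker = threading.Thread(target=next, args=(gen,), name='worker')
    worker.start()
    worker.join()


caller = threading.current_thread().name

# Case 1: the generator is created here but first advanced in the worker,
# so its setup code runs in the worker thread.
log_lazy = []
first_next_in_worker(make_gen(log_lazy))

# Case 2: calling next() once in the calling thread first (like the extra
# next(datagen) line) makes the setup code run in the calling thread.
log_primed = []
gen = make_gen(log_primed)
next(gen)
first_next_in_worker(gen)

print(log_lazy, log_primed)
```

In the first case the setup is recorded under the worker's name, in the second under the caller's, which matches the pattern the question describes.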