TensorFlow Keras with tf.data dataset input

Question

I'm new to TensorFlow Keras and the Dataset API. Can anyone help me understand why the following code doesn't work?

import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.keras.utils import multi_gpu_model
from tensorflow.python.keras import backend as K


data = np.random.random((1000,32))
labels = np.random.random((1000,10))
dataset = tf.data.Dataset.from_tensor_slices((data,labels))
print( dataset)
print( dataset.output_types)
print( dataset.output_shapes)
dataset.batch(10)
dataset.repeat(100)

inputs = keras.Input(shape=(32,))  # Returns a placeholder tensor

# A layer instance is callable on a tensor, and returns a tensor.
x = keras.layers.Dense(64, activation='relu')(inputs)
x = keras.layers.Dense(64, activation='relu')(x)
predictions = keras.layers.Dense(10, activation='softmax')(x)

# Instantiate the model given inputs and outputs.
model = keras.Model(inputs=inputs, outputs=predictions)

# The compile step specifies the training configuration.
model.compile(optimizer=tf.train.RMSPropOptimizer(0.001),
          loss='categorical_crossentropy',
          metrics=['accuracy'])

# Trains for 5 epochs
model.fit(dataset, epochs=5, steps_per_epoch=100)

It failed with the following error:

model.fit(x=dataset, y=None, epochs=5, steps_per_epoch=100)
File "/home/wuxinyu/pyEnv/lib/python3.5/site-packages/tensorflow/python/keras/engine/training.py", line 1510, in fit
validation_split=validation_split)
File "/home/wuxinyu/pyEnv/lib/python3.5/site-packages/tensorflow/python/keras/engine/training.py", line 994, in _standardize_user_data
class_weight, batch_size)
File "/home/wuxinyu/pyEnv/lib/python3.5/site-packages/tensorflow/python/keras/engine/training.py", line 1113, in _standardize_weights
exception_prefix='input')
File "/home/wuxinyu/pyEnv/lib/python3.5/site-packages/tensorflow/python/keras/engine/training_utils.py", line 325, in standardize_input_data
'with shape ' + str(data_shape))
ValueError: Error when checking input: expected input_1 to have 2 dimensions, but got array with shape (32,)

According to the tf.keras guide, I should be able to pass the dataset directly to model.fit, as this example shows:

Input tf.data datasets

Use the Datasets API to scale to large datasets or multi-device training. Pass a tf.data.Dataset instance to the fit method:

# Instantiates a toy dataset instance:
dataset = tf.data.Dataset.from_tensor_slices((data, labels))
dataset = dataset.batch(32)
dataset = dataset.repeat()

Don't forget to specify steps_per_epoch when calling fit on a dataset.

model.fit(dataset, epochs=10, steps_per_epoch=30)

Here, the fit method uses the steps_per_epoch argument: this is the number of training steps the model runs before it moves to the next epoch. Since the Dataset yields batches of data, this snippet does not require a batch_size.

Datasets can also be used for validation:

dataset = tf.data.Dataset.from_tensor_slices((data, labels))
dataset = dataset.batch(32).repeat()

val_dataset = tf.data.Dataset.from_tensor_slices((val_data, val_labels))
val_dataset = val_dataset.batch(32).repeat()

model.fit(dataset, epochs=10, steps_per_epoch=30,
      validation_data=val_dataset,
      validation_steps=3)

What's the problem with my code, and what's the correct way of doing it?

Recommended answer

To your original question as to why you're getting the error:

Error when checking input: expected input_1 to have 2 dimensions, but got array with shape (32,)

Your code breaks because you never assigned the result of .batch() back to the dataset variable, like so:

dataset = dataset.batch(10)

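You only called dataset.batch(10) and discarded its return value. Dataset transformations such as batch() and repeat() return a new dataset rather than modifying the original in place, so the dataset.repeat(100) call in your code has no effect either.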
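A quick way to check this is to compare the dataset's element shape before and after the assignment. The sketch below assumes TensorFlow 2.x, where element_spec is available (the question's TF 1.x code prints output_shapes instead); the variable names unbatched_shape and batched_shape are only illustrative.

```python
import numpy as np
import tensorflow as tf

data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))

dataset = tf.data.Dataset.from_tensor_slices((data, labels))

# batch() returns a new dataset. The result is discarded here,
# so `dataset` still yields single, unbatched examples.
dataset.batch(10)
unbatched_shape = dataset.element_spec[0].shape

# Assigning the result back is what actually changes `dataset`.
dataset = dataset.batch(10)
batched_shape = dataset.element_spec[0].shape

print("features before assignment:", unbatched_shape)
print("features after assignment: ", batched_shape)
```

The first shape has a single dimension (the 32 features), while the second has a leading batch dimension in front of it, which is what the model's input layer expects.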

This breaks because, without batch(), the dataset yields unbatched elements: each input has shape (32,) instead of a batched shape such as (10, 32), while the model expects two dimensions (batch, features).
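Putting the fix together, a corrected version of the question's script might look like the following. This is a sketch assuming TensorFlow 2.x rather than the asker's original environment: it swaps tf.train.RMSPropOptimizer (a TF 1.x API) for keras.optimizers.RMSprop with the same learning rate, drops the unused private tensorflow.python imports, and uses repeat() without a count so that steps_per_epoch alone decides when each epoch ends.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

data = np.random.random((1000, 32)).astype("float32")
labels = np.random.random((1000, 10)).astype("float32")

# Each transformation returns a new dataset, so reassign every time.
dataset = tf.data.Dataset.from_tensor_slices((data, labels))
dataset = dataset.batch(10)  # 1000 samples / 10 per batch = 100 batches per pass
dataset = dataset.repeat()   # repeat indefinitely; fit() stops via steps_per_epoch

inputs = keras.Input(shape=(32,))
x = keras.layers.Dense(64, activation="relu")(inputs)
x = keras.layers.Dense(64, activation="relu")(x)
predictions = keras.layers.Dense(10, activation="softmax")(x)
model = keras.Model(inputs=inputs, outputs=predictions)

# Assumption: TF 2.x optimizer in place of tf.train.RMSPropOptimizer(0.001).
model.compile(optimizer=keras.optimizers.RMSprop(learning_rate=0.001),
              loss="categorical_crossentropy",
              metrics=["accuracy"])

# 100 steps of 10 samples each covers the 1000 samples once per epoch.
history = model.fit(dataset, epochs=5, steps_per_epoch=100, verbose=0)
print(history.history["loss"])
```

With the batched dataset, each step feeds the model a (10, 32) input, so the dimension check that raised the ValueError no longer fails.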
