ValueError: Cannot take the length of Shape with unknown rank
Question
I am trying to move our input pipelines to the TensorFlow Dataset API. For that purpose, we have converted the images and labels to TFRecords. We then read the TFRecords back through the Dataset API and check whether the original data and the data read back are the same. So far so good. Below is the code that reads the TFRecords into the dataset:
```python
import tensorflow as tf  # TF 1.x API, as in the original code


def _parse_function2(proto):
    # Define the TFRecord features again. The patches were saved as raw byte strings.
    keys_to_features = {
        "im_path": tf.FixedLenSequenceFeature([], tf.string, allow_missing=True),
        "im_shape": tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
        "score_shape": tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
        "geo_shape": tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
        "im_patches": tf.FixedLenSequenceFeature([], tf.string, allow_missing=True),
        "score_patches": tf.FixedLenSequenceFeature([], tf.string, allow_missing=True),
        "geo_patches": tf.FixedLenSequenceFeature([], tf.string, allow_missing=True),
    }
    # Load one example
    parsed_features = tf.parse_single_example(serialized=proto, features=keys_to_features)
    # Each *_patches feature holds one byte string; take it out of the length-1 vector
    parsed_features['im_patches'] = parsed_features['im_patches'][0]
    parsed_features['score_patches'] = parsed_features['score_patches'][0]
    parsed_features['geo_patches'] = parsed_features['geo_patches'][0]
    # Decode the raw bytes and restore each array's stored shape
    parsed_features['im_patches'] = tf.decode_raw(parsed_features['im_patches'], tf.uint8)
    parsed_features['im_patches'] = tf.reshape(parsed_features['im_patches'], parsed_features['im_shape'])
    parsed_features['score_patches'] = tf.decode_raw(parsed_features['score_patches'], tf.uint8)
    parsed_features['score_patches'] = tf.reshape(parsed_features['score_patches'], parsed_features['score_shape'])
    parsed_features['geo_patches'] = tf.decode_raw(parsed_features['geo_patches'], tf.int16)
    parsed_features['geo_patches'] = tf.reshape(parsed_features['geo_patches'], parsed_features['geo_shape'])
    return (parsed_features['im_patches'],
            tf.cast(parsed_features["score_patches"], tf.int16),
            parsed_features["geo_patches"])


def create_dataset2(tfrecord_path):
    # A list of several TFRecord files works as well
    dataset = tf.data.TFRecordDataset([tfrecord_path], compression_type="ZLIB")
    # Apply the parser to every serialized record; num_parallel_calls sets
    # how many records are parsed in parallel
    dataset = dataset.map(_parse_function2, num_parallel_calls=8)
    # Repeat the dataset indefinitely
    dataset = dataset.repeat()
    # Set the batch size
    dataset = dataset.batch(1)
    return dataset
```
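A quick way to see what Keras will see is to look at the *static* shape of the dataset elements instead of running them. The sketch below is a minimal, hypothetical reproduction (toy `make_example` and `parse` helpers, a single 1×4×6×3 array, and the TF 2.x `tf.io` API rather than the TF 1.x calls above). Because `tf.reshape` receives its target shape as a tensor of statically unknown length (`FixedLenSequenceFeature` parses `im_shape` into a vector of shape `[None]`), TensorFlow cannot infer how many dimensions the result has, so the static rank is unknown even though the actual element is 4-D.

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)


def make_example(arr):
    """Hypothetical helper: serialize an array as raw bytes plus its shape."""
    feature = {
        "im_shape": tf.train.Feature(
            int64_list=tf.train.Int64List(value=list(arr.shape))),
        "im_patches": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[arr.tobytes()])),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()


def parse(proto):
    """Same pattern as _parse_function2, reduced to the image feature."""
    features = {
        "im_shape": tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
        "im_patches": tf.io.FixedLenSequenceFeature([], tf.string, allow_missing=True),
    }
    parsed = tf.io.parse_single_example(proto, features)
    im = tf.io.decode_raw(parsed["im_patches"][0], tf.uint8)
    # The target shape is a run-time tensor of unknown length, so the
    # number of output dimensions cannot be inferred statically.
    return tf.reshape(im, parsed["im_shape"])


arr = np.zeros((1, 4, 6, 3), dtype=np.uint8)
ds = tf.data.Dataset.from_tensors(make_example(arr)).map(parse)

print(ds.element_spec.shape)   # static shape: rank unknown
print(next(iter(ds)).shape)    # run-time shape of the element
```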
Now the dataset created by the above function is passed to the model.fit method as follows. I am following a GitHub gist that gives an example of how to pass a dataset into model.fit:
```python
train_tfrecord = 'data/tfrecords/train/train.tfrecords'
test_tfrecord = 'data/tfrecords/test/test.tfrecords'
train_dataset = create_dataset2(train_tfrecord)
test_dataset = create_dataset2(test_tfrecord)

model.fit(
    train_dataset.make_one_shot_iterator(),
    steps_per_epoch=5,
    epochs=10,
    shuffle=True,
    validation_data=test_dataset.make_one_shot_iterator(),
    callbacks=[function1, function2, function3],
    verbose=1)
```
But the model.fit call above fails with the error: ValueError: Cannot take the length of Shape with unknown rank.
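For context, the message itself comes from `tf.TensorShape`: calling `len()` on a shape whose rank is not known statically raises this `ValueError` (the capitalization of "Shape" in the message differs between TensorFlow versions), whereas the `.rank` property simply returns `None`. A minimal illustration, assuming TensorFlow is installed; the shape values are taken from the output later in the question:

```python
import tensorflow as tf

known = tf.TensorShape([1, 3553, 2529, 3])
unknown = tf.TensorShape(None)  # a shape whose rank is not known statically

print(len(known))     # 4
print(unknown.rank)   # None: .rank does not raise for an unknown rank
try:
    len(unknown)      # any code that assumes a known rank fails here
except ValueError as err:
    print("ValueError:", err)
```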
EDIT 1: I am using the code below to iterate through the dataset and extract the rank, shape and type of the tensors.
```python
import tensorflow as tf  # TF 1.x API

train_tfrecord = 'data/tfrecords/train/train.tfrecords'
with tf.Graph().as_default():
    # Deserialize the records and report on what was read back
    dataset = tf.data.TFRecordDataset([train_tfrecord], compression_type="ZLIB")
    dataset = dataset.map(_parse_function2)
    iterator = dataset.make_one_shot_iterator()

    # Build the ops once, outside the loop, so the graph does not grow on every iteration
    im_patches, score_patches, geo_patches = iterator.get_next()
    rank_im_shape = tf.rank(im_patches)
    rank_score_shape = tf.rank(score_patches)
    rank_geo_shape = tf.rank(geo_patches)
    shape_im_shape = tf.shape(im_patches)
    shape_score_shape = tf.shape(score_patches)
    shape_geo_shape = tf.shape(geo_patches)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        while True:
            try:
                [some_imshape, some_scoreshape, some_geoshape,
                 some_rank_im_shape, some_rank_score_shape, some_rank_geo_shape,
                 some_shape_im_shape, some_shape_score_shape, some_shape_geo_shape] = \
                    sess.run([im_patches, score_patches, geo_patches,
                              rank_im_shape, rank_score_shape, rank_geo_shape,
                              shape_im_shape, shape_score_shape, shape_geo_shape])
                print("Rank of the 3 patches ")
                print(some_rank_im_shape)
                print(some_rank_score_shape)
                print(some_rank_geo_shape)
                print("Shapes of the 3 patches ")
                print(some_shape_im_shape)
                print(some_shape_score_shape)
                print(some_shape_geo_shape)
                print("Types of the 3 patches ")
                print(type(im_patches))
                print(type(score_patches))
                print(type(geo_patches))
            except tf.errors.OutOfRangeError:
                break
```
Below is the output for the two records:
```
Rank of the 3 patches
4
4
4
Shapes of the 3 patches
[   1 3553 2529    3]
[   1 3553 2529    2]
[   1 3553 2529    5]
Types of the 3 patches
<class 'tensorflow.python.framework.ops.Tensor'>
<class 'tensorflow.python.framework.ops.Tensor'>
<class 'tensorflow.python.framework.ops.Tensor'>
Rank of the 3 patches
4
4
4
Shapes of the 3 patches
[   1 3553 5025    3]
[   1 3553 5025    2]
[   1 3553 5025    5]
Types of the 3 patches
<class 'tensorflow.python.framework.ops.Tensor'>
<class 'tensorflow.python.framework.ops.Tensor'>
<class 'tensorflow.python.framework.ops.Tensor'>
```
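Note that the ranks printed above are *dynamic* ranks computed by `tf.rank` at run time; Keras inspects the *static* shape, which is lost at the `tf.reshape` step. A common workaround (not necessarily the one from the GitHub issue linked in the answer) is to declare the known part of the shape inside the parse function with `Tensor.set_shape` (or `tf.ensure_shape`, which additionally checks the shape at run time). The sketch below assumes TF 2.x and uses a hypothetical toy record; the `[None, None, None, 3]` shape is inferred from the output above (rank 4, 3 image channels) and would end in 2 and 5 for the score and geo patches.

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x


def make_example(arr):
    """Hypothetical helper: serialize an array as raw bytes plus its shape."""
    feature = {
        "im_shape": tf.train.Feature(
            int64_list=tf.train.Int64List(value=list(arr.shape))),
        "im_patches": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[arr.tobytes()])),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()


def parse_with_static_rank(proto):
    features = {
        "im_shape": tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
        "im_patches": tf.io.FixedLenSequenceFeature([], tf.string, allow_missing=True),
    }
    parsed = tf.io.parse_single_example(proto, features)
    im = tf.io.decode_raw(parsed["im_patches"][0], tf.uint8)
    im = tf.reshape(im, parsed["im_shape"])
    # Declare what is known ahead of time: rank 4 with 3 channels.
    # Sizes that vary between records (height, width) stay None.
    im.set_shape([None, None, None, 3])
    return im


arr = np.zeros((1, 5, 7, 3), dtype=np.uint8)
ds = tf.data.Dataset.from_tensors(make_example(arr)).map(parse_with_static_rank)
print(ds.element_spec.shape)           # (None, None, None, 3)
print(ds.batch(1).element_spec.shape)  # (None, None, None, None, 3)
```

The stored shapes in the output above already start with a leading 1, so the `.batch(1)` in `create_dataset2` produces 5-D batches; whether that matches the model's input layer depends on the model, which the question does not show.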
One thing I did notice: if I return the multiple labels as a list and check the returned values with the iterator script above, I get an error.
```python
def _parse_function2(proto):
    # ... everything the same as above ...
    # ... only the labels are now returned as a list:
    return parsed_features['im_patches'], [tf.cast(parsed_features["score_patches"], tf.int16), parsed_features["geo_patches"]]
```
Capturing the above return values as:
```python
while True:
    try:
        next_element = iterator.get_next()
        im_patches, [score_patches, geo_patches] = next_element
        # ... rest of the loop as above ...
    except tf.errors.OutOfRangeError:
        break
```
The error is: TypeError: Tensor objects are only iterable when eager execution is enabled. To iterate over this tensor use tf.map_fn.
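This `TypeError` is consistent with how `tf.data` handles nested structures: its internal copy of `nest` treats tuples and dicts as levels of nesting but does not treat Python lists as one, so a list inside the returned value is converted into a single tensor, which cannot be unpacked in graph mode. Returning a nested tuple instead keeps the two labels as separate components. A minimal sketch under TF 2.x, with hypothetical constant tensors standing in for the three patches:

```python
import tensorflow as tf  # assumes TensorFlow 2.x


def parse_toy(x):
    # Hypothetical stand-ins for im_patches, score_patches and geo_patches.
    im = tf.fill([2, 2, 3], x)
    score = tf.cast(tf.fill([2, 2, 2], x), tf.int16)
    geo = tf.cast(tf.fill([2, 2, 5], x), tf.int16)
    # A tuple (not a list) keeps the two labels as separate components.
    return im, (score, geo)


ds = tf.data.Dataset.range(2).map(parse_toy)
for im, (score, geo) in ds:
    print(im.shape, score.shape, geo.shape)
```

In the TF 1.x iterator loop, the same structure unpacks as `im_patches, (score_patches, geo_patches) = iterator.get_next()`.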
EDIT 2: Here's the definition of the fit function. It seems it can take TensorFlow tensors as well as steps_per_epoch.
```python
def fit(self,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose=1,
        callbacks=None,
        validation_split=0.,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False,
        **kwargs):
    """Trains the model for a fixed number of epochs (iterations on a dataset).

    Arguments:
        x: Input data. It could be:
          - A Numpy array (or array-like), or a list of arrays
            (in case the model has multiple inputs).
          - A TensorFlow tensor, or a list of tensors
            (in case the model has multiple inputs).
          - A dict mapping input names to the corresponding array/tensors,
            if the model has named inputs.
          - A `tf.data` dataset or a dataset iterator. Should return a tuple
            of either `(inputs, targets)` or
            `(inputs, targets, sample_weights)`.
          - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
            or `(inputs, targets, sample weights)`.
        y: Target data. Like the input data `x`,
          it could be either Numpy array(s) or TensorFlow tensor(s).
          It should be consistent with `x` (you cannot have Numpy inputs and
          tensor targets, or inversely). If `x` is a dataset, dataset
          iterator, generator, or `keras.utils.Sequence` instance, `y` should
          not be specified (since targets will be obtained from `x`).
        ...
    """
```
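Per this docstring, `x` can be the dataset itself yielding `(inputs, targets)` tuples, in which case `y` is omitted. A minimal sketch of that call pattern, assuming TF 2.x Keras and a hypothetical one-output toy model (not the question's model): because the datasets repeat forever, `steps_per_epoch` marks the end of a training epoch and `validation_steps` does the same for the validation set, an argument the question's call does not pass.

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x with tf.keras

# Hypothetical toy model: 8x8 RGB patches in, 2-channel score map out.
inputs = tf.keras.Input(shape=(8, 8, 3))
hidden = tf.keras.layers.Conv2D(4, 3, padding="same", activation="relu")(inputs)
outputs = tf.keras.layers.Conv2D(2, 1)(hidden)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="mse")

rng = np.random.default_rng(0)
images = rng.random((4, 8, 8, 3), dtype=np.float32)
scores = rng.random((4, 8, 8, 2), dtype=np.float32)

# Each element is an (inputs, targets) tuple; repeat() makes it endless.
train_ds = tf.data.Dataset.from_tensor_slices((images, scores)).batch(2).repeat()
val_ds = tf.data.Dataset.from_tensor_slices((images, scores)).batch(2).repeat()

history = model.fit(
    train_ds,             # the dataset itself: no iterator and no y
    steps_per_epoch=2,    # batches that make up one epoch
    epochs=2,
    validation_data=val_ds,
    validation_steps=1,   # batches drawn from val_ds per evaluation
    verbose=0,
)
print(sorted(history.history))
```

The question's call also passes `shuffle=True`; as far as I know, tf.keras ignores `shuffle` when `x` is a `tf.data` dataset, so shuffling would have to happen in the pipeline itself, e.g. with `dataset.shuffle(buffer_size)`.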
Recommended answer
It seems this is a bug in the tensorflow.keras module. A working fix has been suggested in the GitHub issue below:
https://github.com/tensorflow/tensorflow/issues/24520