Flatten a dataset in TensorFlow
Problem description
I am trying to convert a dataset in TensorFlow to have several single-valued tensors. The dataset currently looks like this:
[12 43 64 34 45 2 13 54] [34 65 34 67 87 12 23 43] [23 53 23 1 5] ...
After the transformation it should look like this:
[12] [43] [64] [34] [45] [2] [13] [54] [34] [65] [34] [67] [87] [12] ...
My initial idea was to use flat_map on the dataset and then convert each tensor to a list of tensors using reshape and unstack:
output_labels = self.dataset.flat_map(convert_labels)
...
def convert_labels(tensor):
    id_list = tf.unstack(tf.reshape(tensor, [-1, 1]))
    return tf.data.Dataset.from_tensors(id_list)
However, the shape of each tensor is only partially known (i.e. (?, 1)), which is why the unstack operation fails. Is there any way to still "concat" the different tensors without explicitly iterating over them?
Recommended answer
Your solution is very close, but Dataset.flat_map() takes a function that returns a tf.data.Dataset object, rather than a list of tensors. Fortunately, the Dataset.from_tensor_slices() method works for exactly your use case, because it can split a tensor into a variable number of elements:
output_labels = self.dataset.flat_map(tf.data.Dataset.from_tensor_slices)
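For illustration, here is a self-contained sketch of that one-liner applied to the sample rows from the question. It assumes TensorFlow 2.4 or later running eagerly (for the `output_signature` argument of `from_generator` and for `as_numpy_iterator()`); the names `ROWS`, `make_ragged_dataset`, and `flatten` are hypothetical helpers, not part of the original answer. Each input element has shape `(None,)`, so its length is unknown until runtime, which is exactly the case where `tf.unstack` fails but `from_tensor_slices` works.

```python
import tensorflow as tf

# Hypothetical sample data: the rows from the question, with different lengths.
ROWS = [
    [12, 43, 64, 34, 45, 2, 13, 54],
    [34, 65, 34, 67, 87, 12, 23, 43],
    [23, 53, 23, 1, 5],
]


def make_ragged_dataset():
    # Each element is a 1-D int32 tensor whose length is not known statically.
    return tf.data.Dataset.from_generator(
        lambda: iter(ROWS),
        output_signature=tf.TensorSpec(shape=[None], dtype=tf.int32),
    )


def flatten(dataset):
    # from_tensor_slices turns one row into a Dataset with one element per entry;
    # flat_map then concatenates those per-row datasets in order.
    return dataset.flat_map(tf.data.Dataset.from_tensor_slices)


flat = flatten(make_ragged_dataset())
# Collect the flattened scalars as plain Python ints.
values = [int(x) for x in flat.as_numpy_iterator()]
print(values)
```

Each resulting element is a scalar tensor. If you need shape-(1,) elements such as `[12]`, reshape each row to `[-1, 1]` before slicing it.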
Note that the tf.contrib.data.unbatch() transformation implements the same functionality, and has a slightly more efficient implementation in the current master branch of TensorFlow (which will be included in the 1.9 release):
output_labels = self.dataset.apply(tf.contrib.data.unbatch())
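`tf.contrib` was removed in TensorFlow 2.0, and in TensorFlow 2.x the same transformation is available as the `Dataset.unbatch()` method. The sketch below is an assumption-labeled illustration rather than part of the original answer. It assumes TensorFlow 2.4 or later running eagerly, and reuses the hypothetical `ROWS` sample data. It also shows how to reshape each row to `[-1, 1]` first so that every output element keeps shape `(1,)`, matching the `[12] [43] ...` layout in the question.

```python
import tensorflow as tf

# Hypothetical sample data: the rows from the question, with different lengths.
ROWS = [
    [12, 43, 64, 34, 45, 2, 13, 54],
    [34, 65, 34, 67, 87, 12, 23, 43],
    [23, 53, 23, 1, 5],
]

dataset = tf.data.Dataset.from_generator(
    lambda: iter(ROWS),
    output_signature=tf.TensorSpec(shape=[None], dtype=tf.int32),
)

# unbatch() splits every element along its first dimension, even when that
# dimension is unknown statically, giving one scalar per value.
flat = dataset.unbatch()
values = [int(x) for x in flat.as_numpy_iterator()]

# Reshaping each row to (?, 1) before unbatching yields shape-(1,) elements.
column = dataset.map(lambda row: tf.reshape(row, [-1, 1])).unbatch()
column_values = [x.tolist() for x in column.as_numpy_iterator()]
print(column_values[:3])
```

Compared with `flat_map(tf.data.Dataset.from_tensor_slices)`, `unbatch()` states the intent directly and avoids building a nested dataset per element.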