Indexing in TensorFlow slower than gather
Problem description
I am trying to index into a tensor to get a slice or a single element from 1-D tensors. I find a significant performance difference (almost 30-40%) between the NumPy way of indexing ([:]) or slice and tf.gather.
I also observe that tf.gather has significant overhead when used on scalars (looping over an unstacked tensor) as opposed to a tensor. Is this a known issue?
Example code (inefficient):
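for node_idxs in graph.nodes():
    # Split the index tensor into scalars and gather one element at a time.
    node_indices_list = tf.unstack(node_idxs)
    result = []
    for nodeid in node_indices_list:
        x = tf.gather(..., nodeid)
        y = tf.gather(..., nodeid)
        # tf.mul was renamed tf.multiply in TensorFlow 1.0.
        result.append(tf.multiply(x, y))
    return tf.stack(result)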
Versus example code (efficient):
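for node_idxs in graph.nodes():
    # Gather all indices in one vectorized call per tensor.
    x = tf.gather(..., node_idxs)
    y = tf.gather(..., node_idxs)
    # tf.mul was renamed tf.multiply in TensorFlow 1.0.
    return tf.multiply(x, y)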
I understand that the first, inefficient implementation does more work (unstacking, stacking, looping, and many more gather operations), but I was not expecting a 100x slowdown when the number of nodes I am operating on is only a few hundred. Are unstacking and the overhead of gather on a single scalar really that slow? In the first case I have many more gather operations, each operating on a single element rather than on a tensor of offsets. Is there a faster way of indexing? I tried NumPy-style indexing and slice, which turned out to be slower than gather.
Recommended answer
First, the code doesn't really compare gather with NumPy indexing: it compares vectorized indexing (tf.gather) with looped indexing (a Python for loop). It is no surprise that looping is slow.
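As a rough illustration of the two patterns being compared, here is a minimal sketch assuming TensorFlow 2.x in eager mode. The tensors values_a and values_b are hypothetical stand-ins for the arguments elided as ... in the question, and the function names looped and vectorized are made up for this example. It only checks that both versions produce the same result; it does not measure timing.

```python
import tensorflow as tf

# Hypothetical data standing in for the tensors elided as `...` in the question.
values_a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0])
values_b = tf.constant([10.0, 20.0, 30.0, 40.0, 50.0])
node_idxs = tf.constant([4, 0, 2])


def looped(idxs):
    """Looped indexing: one tf.gather call per scalar index."""
    result = []
    for nodeid in tf.unstack(idxs):
        x = tf.gather(values_a, nodeid)
        y = tf.gather(values_b, nodeid)
        result.append(tf.multiply(x, y))
    return tf.stack(result)


def vectorized(idxs):
    """Vectorized indexing: one tf.gather call per tensor."""
    x = tf.gather(values_a, idxs)
    y = tf.gather(values_b, idxs)
    return tf.multiply(x, y)


looped_out = looped(node_idxs)
vectorized_out = vectorized(node_idxs)
print(looped_out.numpy(), vectorized_out.numpy())
```

For n indices, the looped version creates 2n gather ops, n multiplies, and one stack, while the vectorized version creates two gathers and one multiply regardless of n.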
Note that NumPy-like indexing, tensor[idxs], is restricted in TensorFlow anyway:
Only integers, slices (:), ellipsis (...), tf.newaxis (None) and scalar tf.int32/tf.int64 tensors are valid indices
So use tf.gather for general applications.
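The index forms named in the message quoted above, and the tf.gather alternative for an arbitrary list of positions, can be sketched as follows. This assumes TensorFlow 2.x in eager mode; the tensor t and the other variable names are hypothetical and chosen only for illustration.

```python
import tensorflow as tf

t = tf.constant([10, 20, 30, 40, 50])

# Index forms accepted by tensor[...] according to the quoted message.
single = t[2]                      # plain integer
sliced = t[1:4]                    # slice
scalar_idx = tf.constant(3)        # scalar tf.int32 tensor
by_scalar_tensor = t[scalar_idx]
expanded = t[tf.newaxis, ...]      # tf.newaxis plus ellipsis, shape (1, 5)

# A 1-D tensor of positions is not one of the listed index forms,
# so tf.gather is used for it instead.
idxs = tf.constant([4, 0, 2])
picked = tf.gather(t, idxs)

print(single.numpy(), sliced.numpy(), by_scalar_tensor.numpy())
print(expanded.shape, picked.numpy())
```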