Indexing in TensorFlow slower than gather

Problem description

I am trying to index into a 1-D tensor to get a slice or a single element. I find a significant performance difference (almost 30-40%) between NumPy-style indexing ([:]) or tf.slice on one hand and tf.gather on the other.
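For reference, here is a minimal sketch of the three ways to take a contiguous range from a 1-D tensor that the question compares. The tensor `t` and the range 2:5 are hypothetical, and this assumes TensorFlow 2.x in eager mode. All three return the same values; they differ only in which op TensorFlow runs.

```python
import tensorflow as tf

# Hypothetical 1-D tensor standing in for the asker's data.
t = tf.constant([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

# NumPy-style slicing; Tensor.__getitem__ lowers this to tf.strided_slice.
by_indexing = t[2:5]
# Explicit tf.slice: start at offset 2 and take 3 elements.
by_slice = tf.slice(t, begin=[2], size=[3])
# tf.gather with an explicit tensor of offsets [2, 3, 4].
by_gather = tf.gather(t, tf.range(2, 5))

print(by_indexing.numpy().tolist())  # [2.0, 3.0, 4.0]
print(by_slice.numpy().tolist())     # [2.0, 3.0, 4.0]
print(by_gather.numpy().tolist())    # [2.0, 3.0, 4.0]
```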

Also, I observe that tf.gather has significant overhead when used on scalars (looping over an unstacked tensor) as opposed to a whole tensor of indices. Is this a known issue?

Example code (inefficient):

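# "..." stands for the params tensor, elided in the original question
for node_idxs in graph.nodes():
    # split the index tensor into a Python list of scalar tensors
    node_indices_list = tf.unstack(node_idxs)
    result = []
    for nodeid in node_indices_list:
        # one scalar gather per index
        x = tf.gather(..., nodeid)
        y = tf.gather(..., nodeid)
        result.append(tf.multiply(x, y))
return tf.stack(result)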

Versus example code (efficient):

for node_idxs in graph.nodes():
    # one gather over the whole index tensor
    x = tf.gather(..., node_idxs)
    y = tf.gather(..., node_idxs)
return tf.multiply(x, y)
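A runnable sketch of the two versions above, assuming TensorFlow 2.x in eager mode. The tensors `features_a` and `features_b` are hypothetical stand-ins for the params elided as "..." in the question, and the outer `graph.nodes()` loop is dropped so that a single index tensor is processed. Both functions compute the same values; the looped one issues two scalar gathers per index, the vectorized one issues two gathers in total.

```python
import tensorflow as tf

# Hypothetical stand-ins for the params tensors elided as "..." above.
features_a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0])
features_b = tf.constant([10.0, 20.0, 30.0, 40.0, 50.0])
node_idxs = tf.constant([4, 0, 2])

def looped(idxs):
    # One scalar gather per index, then re-stack: many tiny ops.
    result = []
    for nodeid in tf.unstack(idxs):
        x = tf.gather(features_a, nodeid)
        y = tf.gather(features_b, nodeid)
        result.append(tf.multiply(x, y))
    return tf.stack(result)

def vectorized(idxs):
    # One gather per params tensor, covering all indices at once.
    x = tf.gather(features_a, idxs)
    y = tf.gather(features_b, idxs)
    return tf.multiply(x, y)

print(looped(node_idxs).numpy().tolist())      # [250.0, 10.0, 90.0]
print(vectorized(node_idxs).numpy().tolist())  # [250.0, 10.0, 90.0]
```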

I understand that the first (inefficient) implementation does more work (unstacking, stacking, looping and extra gather operations), but I was not expecting a 100x slowdown when the number of nodes I am operating on is only a few hundred. Are unstacking and the overhead of gather on a single scalar really that slow? In the first case I have many more gather operations, each operating on a single element, as opposed to one gather over a tensor of offsets. Is there a faster way of indexing? I tried NumPy-style indexing and tf.slice, which turned out to be slower than gather.

Recommended answer

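First, the code doesn't really compare gather with NumPy indexing; it compares vectorized indexing (tf.gather) with looped indexing (a Python "for" loop). It is no surprise that looping is slow: every iteration adds its own gather and multiply ops, so the fixed per-op overhead is paid hundreds of times for single-element work.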
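One way to see this without a timing benchmark is to trace both versions with `tf.function` and count the ops in the resulting graph. This is a sketch assuming TensorFlow 2.x; `a`, `b`, `looped`, `vectorized` and `op_count` are hypothetical names, and op count is only a proxy for dispatch overhead, not a measurement of run time. The looped graph grows with the number of indices, because the Python loop is unrolled at trace time, while the vectorized graph stays the same size.

```python
import tensorflow as tf

# Hypothetical params tensors; only their shapes matter here.
a = tf.random.uniform([1000])
b = tf.random.uniform([1000])

def looped(idxs):
    # tf.unstack needs a static length, so the loop is unrolled when traced.
    result = []
    for i in tf.unstack(idxs):
        result.append(tf.gather(a, i) * tf.gather(b, i))
    return tf.stack(result)

def vectorized(idxs):
    # Two gathers and one multiply, regardless of how many indices there are.
    return tf.gather(a, idxs) * tf.gather(b, idxs)

def op_count(fn, n):
    """Trace fn for an int32 index vector of static length n; count graph ops."""
    spec = tf.TensorSpec([n], tf.int32)
    graph = tf.function(fn).get_concrete_function(spec).graph
    return len(graph.get_operations())

for n in (10, 100, 300):
    print(n, op_count(looped, n), op_count(vectorized, n))
```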

Note that NumPy-like indexing tensor[idxs] is restricted in TensorFlow anyway:

Only integers, slices (:), ellipsis (...), tf.newaxis (None) and scalar tf.int32/tf.int64 tensors are valid indices
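A small sketch of that restriction, assuming TensorFlow 2.x with default (non-NumPy) indexing behavior. The tensor values and the helper `try_index` are hypothetical. A scalar index and a slice work through `t[...]`, a rank-1 tensor of indices is rejected, and `tf.gather` accepts the same index tensor.

```python
import tensorflow as tf

t = tf.constant([10, 20, 30, 40, 50])
idxs = tf.constant([0, 2, 4])

def try_index(tensor, index):
    """Return tensor[index] as Python values, or None if TensorFlow rejects it."""
    try:
        return tensor[index].numpy().tolist()
    except (TypeError, ValueError, tf.errors.InvalidArgumentError):
        return None

print(try_index(t, 1))            # 20
print(try_index(t, slice(1, 3)))  # [20, 30]
print(try_index(t, idxs))         # None: a rank-1 index tensor is not accepted
print(tf.gather(t, idxs).numpy().tolist())  # [10, 30, 50]
```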

So use tf.gather for general applications.
