Tensorflow中张量的逐行处理 [英] Row by row processing on tensor in Tensorflow

查看:72
本文介绍了Tensorflow中张量的逐行处理的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在测试代码以逐行处理张量.

I am testing the code to process row by row of a tensor.

张量可能有一行,其中最后 4 个元素为 0 或非零值.

The tensor may have a row with the last 4 elements are 0 or with non-zero values.

如果该行的最后 4 个元素为 0 [1.0,2.0,2.3,3.4,0,0,0,0]删除最后四个元素,并将形状更改为行中的 5 个元素.第一个元素表示行索引.它变成了 [0.0,1.0,2.0,2.3,3.4].

If the row has 0 for the last 4 elements [1.0,2.0,2.3,3.4,0,0,0,0] The last four are removed and shape is changed to 5 elements in the row. The first element represents the row index. It becomes like [0.0,1.0,2.0,2.3,3.4].

如果该行的所有 8 个元素都具有非零值,则拆分为两行并将行索引放在首位.然后 [3.0,4.0,1.0,2.1,1.2,1.4,1.2,1.5] 变成了 [[2.0,3.0,4.0,1.0,2.1],[2.0,1.2,1.4,1.2,1.5]].第一个元素 2.0 是张量中的行索引.

If the row has all 8 elements with non-zero values, then split into two rows and put row index in the first place. Then [3.0,4.0,1.0,2.1,1.2,1.4,1.2,1.5] becomes like [[2.0,3.0,4.0,1.0,2.1],[2.0,1.2,1.4,1.2,1.5]]. The first element 2.0 is row index in tensor.

所以经过处理<代码>[[1.0,2.0,2.3,3.4,0,0,0,0],[2.0,3.2,4.2,4.0,0,0,0,0],[3.0,4.0,1.0,2.1,1.2,1.4,1.2,1.5],[1.2,1.3,3.4,4.5,1,2,3,4]] 变成

[[0,1.0,2.0,2.3,3.4],[1.0,2.0,3.2,4.2,4.0],[2.0,3.0,4.0,1.0,2.1],[2.0,1.2,1.4,1.2,1.5],[3.0,1.2,1.3,3.4,4.5],[3.0,1,2,3,4]]

我做了如下.但是 error as TypeError: TypeErro...pected',) at map_fn.

    import tensorflow as tf

    boxes = tf.constant([[1.0,2.0,2.3,3.4,0,0,0,0],[2.0,3.2,4.2,4.0,0,0,0,0],[3.0,4.0,1.0,2.1,1.2,1.4,1.2,1.5],[1.2,1.3,3.4,4.5,1,2,3,4]])
    rows = tf.expand_dims(tf.range(tf.shape(boxes)[0], dtype=tf.int32), 1)
    def bbox_organize(box, i):
       if(tf.reduce_sum(box[4:]) == 0):
          box=tf.squeeze(box, [5,6,7]
          box=tf.roll(box, shift=1, axis=0)
          box[0]=i
       else:
          box=tf.reshape(box, [2, 4])
          const_=tf.constant(i, shape=[2, 1])
          box=tf.concat([const_, box], 0)
       return box
    b_boxes= tf.map_fn(lambda x: (bbox_organize(x[0], x[1]), x[1]), (boxes, rows), dtype=(tf.int32, tf.int32))
    with tf.Session() as sess: print(sess.run(b_boxes))

我不擅长 Tensorflow,仍在学习中.

I am not good at Tensorflow and still learning.

有没有更好的方法来实现 Tensorflow api 来处理它?<​​/p>

Is there better way of implementing Tensorflow apis to process it?

推荐答案

你需要做的是变换boxes的形状,得到非零线

What you need to do is transform the shape of boxes and get non-zero lines

import tensorflow as tf
tf.enable_eager_execution();

boxes = tf.constant([[1.0,2.0,2.3,3.4,0,0,0,0],[2.0,3.2,4.2,4.0,0,0,0,0],[3.0,4.0,1.0,2.1,1.2,1.4,1.2,1.5],[1.2,1.3,3.4,4.5,1,2,3,4]])

boxes = tf.reshape(boxes,[tf.shape(boxes)[0]*2,-1])
# [[1.  2.  2.3 3.4]
#  [0.  0.  0.  0. ]
#  [2.  3.2 4.2 4. ]
#  [0.  0.  0.  0. ]
#  [3.  4.  1.  2.1]
#  [1.2 1.4 1.2 1.5]
#  [1.2 1.3 3.4 4.5]
#  [1.  2.  3.  4. ]]

rows = tf.floor(tf.expand_dims(tf.range(tf.shape(boxes)[0]), 1)/2)
# [[0.]
#  [0.]
#  [1.]
#  [1.]
#  [2.]
#  [2.]
#  [3.]
#  [3.]]

add_index = tf.concat([tf.cast(rows,tf.float32),boxes],-1)

index = tf.not_equal(tf.reduce_sum(add_index[:,4:],axis=1),0)
# [ True False  True False  True  True  True  True]

boxes_ = tf.gather_nd(add_index,tf.where(index))
print(boxes_)

# tf.Tensor(
# [[0.  1.  2.  2.3 3.4]
#  [1.  2.  3.2 4.2 4. ]
#  [2.  3.  4.  1.  2.1]
#  [2.  1.2 1.4 1.2 1.5]
#  [3.  1.2 1.3 3.4 4.5]
#  [3.  1.  2.  3.  4. ]], shape=(6, 5), dtype=float32)

我觉得你更需要上面的代码.矢量化方法将明显快于 tf.map_fn().

I think you need the above code more. The vectorization method will be significantly faster than the tf.map_fn().

这篇关于Tensorflow中张量的逐行处理的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆