Batched 4D tensor indexing in TensorFlow

Problem description

Given:

  • batch_images: 4D tensor of shape (B, H, W, C)
  • x: 3D tensor of shape (B, H, W)
  • y: 3D tensor of shape (B, H, W)

Goal

How can I index into batch_images using the x and y coordinates to obtain a 4D tensor of shape (B, H, W, C)? That is, for each batch element and each (x, y) pair, I want to obtain a tensor of shape (C,).

In NumPy, this can be done with input_img[np.arange(B)[:,None,None], y, x], for example, but I can't seem to make it work in TensorFlow.
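For reference, here is a minimal sketch of what that NumPy expression computes, checked against an explicit loop. The sizes and the random data are made up for illustration; only the indexing expression comes from the question.

```python
import numpy as np

# Illustrative sizes (assumptions, not from the question).
B, H, W, C = 2, 3, 4, 5
rng = np.random.default_rng(0)
input_img = rng.uniform(size=(B, H, W, C))
x = rng.integers(0, W, size=(B, H, W))  # column index for each output pixel
y = rng.integers(0, H, size=(B, H, W))  # row index for each output pixel

# np.arange(B)[:, None, None] has shape (B, 1, 1) and broadcasts against
# y and x, so every output pixel reads from its own batch element.
out = input_img[np.arange(B)[:, None, None], y, x]

# Explicit loop with the same meaning, for comparison.
expected = np.empty_like(input_img)
for b in range(B):
    for i in range(H):
        for j in range(W):
            expected[b, i, j] = input_img[b, y[b, i, j], x[b, i, j]]

print(out.shape)
print(np.array_equal(out, expected))
```

The three index arrays broadcast to shape (B, H, W), and the untouched channel axis is appended, which gives the desired (B, H, W, C) result.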

My current attempt

def get_pixel_value(img, x, y):
    """
    Utility function to get pixel value for 
    coordinate vectors x and y from a  4D tensor image.
    """
    H = tf.shape(img)[1]
    W = tf.shape(img)[2]
    C = tf.shape(img)[3]

    # flatten image
    img_flat = tf.reshape(img, [-1, C])

    # flatten idx
    idx_flat = (x*W) + y

    return tf.gather(img_flat, idx_flat)

This returns a tensor with the incorrect shape (B, H, W).

Recommended answer

It should be possible to do it by flattening the tensor as you've done, but the batch dimension has to be taken into account in the index calculation. In order to do this, you'll have to make an additional dummy batch index tensor with the same shape as x and y that always contains the index of the current batch. This is basically the np.arange(B) from your numpy example, which is missing from your TensorFlow code.
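A minimal sketch of that flattened approach, assuming TensorFlow 2.x with eager execution. The name get_pixel_value_flat and the test data are hypothetical. Note that the offset uses y * W + x, because y indexes rows (the H axis) in the NumPy expression from the question.

```python
import numpy as np
import tensorflow as tf

def get_pixel_value_flat(img, x, y):
    """Gather img[b, y, x, :] for every output pixel via a flattened index."""
    shape = tf.shape(img)
    B, H, W, C = shape[0], shape[1], shape[2], shape[3]
    x = tf.cast(x, tf.int32)
    y = tf.cast(y, tf.int32)

    # Batch index with shape (B, 1, 1); it broadcasts against x and y.
    b = tf.reshape(tf.range(B), (-1, 1, 1))

    # Row-major offset of pixel (b, y, x) in the (B*H*W, C) flattened image.
    idx_flat = b * (H * W) + y * W + x

    img_flat = tf.reshape(img, [-1, C])
    return tf.gather(img_flat, idx_flat)

# Made-up example data, compared against the NumPy indexing expression.
img = np.random.uniform(size=(3, 4, 5, 6))
x = np.random.randint(0, 5, size=(3, 4, 5))
y = np.random.randint(0, 4, size=(3, 4, 5))

out = get_pixel_value_flat(img, x, y)
expected = img[np.arange(3)[:, None, None], y, x]
print(out.shape)
print(np.allclose(out.numpy(), expected))
```

The b * (H * W) term is the part that plays the role of np.arange(B): without it, every batch element would read from the first image.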

You can also simplify things a bit by using tf.gather_nd, which does the index calculations for you.

Here is an example:

import numpy as np
import tensorflow as tf

# Example tensors: a batch of 3 images, each 4 x 5 with 6 channels.
M = np.random.uniform(size=(3, 4, 5, 6))
x = np.random.randint(0, 5, size=(3, 4, 5))  # column indices in [0, W)
y = np.random.randint(0, 4, size=(3, 4, 5))  # row indices in [0, H)

def get_pixel_value(img, x, y):
    """
    Utility function that composes a new image, with pixels taken
    from the coordinates given in x and y.
    The shapes of x and y have to match.
    The batch order is preserved.
    """
    # Use one integer dtype for all index components so they can be stacked.
    x = tf.cast(x, tf.int32)
    y = tf.cast(y, tf.int32)

    # We assume that x and y have the same shape.
    shape = tf.shape(x)
    batch_size = shape[0]
    height = shape[1]
    width = shape[2]

    # Create a tensor that indexes into the same batch.
    # This is needed for gather_nd to work.
    batch_idx = tf.range(0, batch_size)
    batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1))
    b = tf.tile(batch_idx, (1, height, width))

    # tf.pack was renamed to tf.stack in TensorFlow 1.0.
    indices = tf.stack([b, y, x], 3)
    return tf.gather_nd(img, indices)

# With eager execution (the TensorFlow 2.x default) no session is needed.
print(get_pixel_value(M, x, y).shape)
# Should print (3, 4, 5, 6).
# We've composed a new image of the same size from randomly picked x and y
# coordinates of each original image.
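
As a variation that is not part of the original answer: in TensorFlow versions where tf.gather_nd accepts a batch_dims argument (this sketch assumes 2.x), the explicit batch index tensor may not be needed at all. The example data below is made up for illustration.

```python
import numpy as np
import tensorflow as tf

# Made-up example data with B=3, H=4, W=5, C=6.
img = np.random.uniform(size=(3, 4, 5, 6))
x = np.random.randint(0, 5, size=(3, 4, 5)).astype(np.int32)
y = np.random.randint(0, 4, size=(3, 4, 5)).astype(np.int32)

# Indices hold only (row, column) pairs; shape (B, H, W, 2).
indices = tf.stack([y, x], axis=-1)

# batch_dims=1 pairs batch element b of img with batch element b of indices,
# so no explicit batch index needs to be built.
out = tf.gather_nd(img, indices, batch_dims=1)

# Compare against the NumPy indexing expression from the question.
expected = img[np.arange(3)[:, None, None], y, x]
print(out.shape)
print(np.allclose(out.numpy(), expected))
```

This should give the same result as stacking [b, y, x] by hand; the explicit version above remains useful on older TensorFlow releases that lack batch_dims.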
