Proper usage of `tf.scatter_nd` in tensorflow-r1.2


Question


Given indices with shape [batch_size, sequence_len], updates with shape [batch_size, sequence_len, sampled_size], to_shape with shape [batch_size, sequence_len, vocab_size], where vocab_size >> sampled_size, I'd like to use tf.scatter to map the updates to a huge tensor with to_shape, such that to_shape[bs, indices[bs, sz]] = updates[bs, sz]. That is, I'd like to map the updates to to_shape row by row. Please note that sequence_len and sampled_size are scalar tensors, while others are fixed. I tried to do the following:

new_tensor = tf.scatter_nd(tf.expand_dims(indices, axis=2), updates, to_shape)

But I get an error:

ValueError: The inner 2 dimension of output.shape=[?,?,?] must match the inner 1 dimension of updates.shape=[80,50,?]: Shapes must be equal rank, but are 2 and 1 for .... with input shapes: [80, 50, 1], [80, 50,?], [3]


Could you please tell me how to use scatter_nd properly? Thanks in advance!

Answer

So say you have:

  • A tensor updates with shape [batch_size, sequence_len, sampled_size].
  • A tensor indices with shape [batch_size, sequence_len, sampled_size].

Then you would do:

import tensorflow as tf

# Create updates and indices...

# Create additional indices
i1, i2 = tf.meshgrid(tf.range(batch_size),
                     tf.range(sequence_len), indexing="ij")
i1 = tf.tile(i1[:, :, tf.newaxis], [1, 1, sampled_size])
i2 = tf.tile(i2[:, :, tf.newaxis], [1, 1, sampled_size])
# Create final indices
idx = tf.stack([i1, i2, indices], axis=-1)
# Output shape
to_shape = [batch_size, sequence_len, vocab_size]
# Get scattered tensor
output = tf.scatter_nd(idx, updates, to_shape)


tf.scatter_nd takes an indices tensor, an updates tensor, and a shape. updates is the original tensor, and the shape is just the desired output shape, so [batch_size, sequence_len, vocab_size]. Now, indices is more complicated. Since your output has 3 dimensions (rank 3), for each element in updates you need 3 indices to determine where in the output that element will be placed. So the shape of the indices parameter should be the same as updates, with an additional trailing dimension of size 3. In this case, we want the first two dimensions to stay the same, but we still have to specify all 3 indices for each element. So we use tf.meshgrid to generate the batch and sequence indices we need and tile them along the third dimension (the first and second index are the same for every element in the last dimension of updates). Finally, we stack these with the mapping indices you already have, giving the full 3-dimensional indices.
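The same index construction can be sanity-checked in plain NumPy. This is an illustrative sketch with small made-up sizes (2/3/2/5 are arbitrary), where `np.add.at` plays the role of `tf.scatter_nd`, which also sums contributions on duplicate indices:

```python
import numpy as np

# Small, made-up sizes for illustration only
batch_size, sequence_len, sampled_size, vocab_size = 2, 3, 2, 5

rng = np.random.default_rng(0)
updates = rng.standard_normal((batch_size, sequence_len, sampled_size))
indices = rng.integers(0, vocab_size, size=(batch_size, sequence_len, sampled_size))

# Batch and sequence index grids, mirroring tf.meshgrid(..., indexing="ij")
i1, i2 = np.meshgrid(np.arange(batch_size), np.arange(sequence_len), indexing="ij")
i1 = np.tile(i1[:, :, np.newaxis], (1, 1, sampled_size))
i2 = np.tile(i2[:, :, np.newaxis], (1, 1, sampled_size))

# Full indices, shape [batch_size, sequence_len, sampled_size, 3]
idx = np.stack([i1, i2, indices], axis=-1)

# Scatter: np.add.at accumulates on duplicate indices, like tf.scatter_nd
output = np.zeros((batch_size, sequence_len, vocab_size))
np.add.at(output, (idx[..., 0], idx[..., 1], idx[..., 2]), updates)
```

Because every update lands somewhere in the output and duplicates are summed, `output.sum()` equals `updates.sum()`, which is a quick way to verify the scatter did not drop anything.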

