Tensorflow 2 - tensor_scatter_nd_update 中的“索引深度"是什么? [英] Tensorflow 2 - what is 'index depth' in tensor_scatter_nd_update?

查看:38
本文介绍了Tensorflow 2 - tensor_scatter_nd_update 中的“索引深度"是什么?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

请解释什么是

tensor = [[1, 1], [1, 1], [1, 1]] # tf.rank(tensor) == 2索引 = [[0, 1], [2, 0]] # num_updates == 2, index_depth == 2更新 = [5, 10] # num_updates == 2打印(tf.tensor_scatter_nd_update(张量,索引,更新))

解决方案

对于索引索引深度是大小或长度索引向量.例如:

indicesA = [[1], [3], [4], [7]] # 1个元素的索引向量:index_depth = 1indexB = [[0, 1], [2, 0]] # 2 个元素的索引向量:index_depth = 2


索引的原因是 2D 保存两个信息,一个是更新的长度 (num_updates)>索引向量的长度.需要完成两件事:

  • indicesindex depth必须等于input张量的rank
  • updates 的长度必须等于 indices
  • length

所以,在示例代码中

# tf.rank(tensor) == 1张量 = [0, 0, 0, 0, 0, 0, 0, 0]# num_updates == 4, index_depth == 1 |tf.rank(indices).numpy() == 2指数 = [[1], [3], [4], [7]]# num_updates == 4 |tf.rank(output).numpy() == 1更新 = [9, 10, 11, 12]输出 = tf.tensor_scatter_nd_update(张量,索引,更新)tf.Tensor([ 0 9 0 10 11 0 0 12], shape=(8,), dtype=int32)

还有

# tf.rank(tensor) == 2张量 = [[1, 1], [1, 1], [1, 1]]# num_updates == 2, index_depth == 2 |tf.rank(indices).numpy() == 2指数 = [[0, 1], [2, 0]]# num_updates == 2 |tf.rank(output).numpy() == 2更新 = [5, 10]输出 = tf.tensor_scatter_nd_update(张量,索引,更新)tf.张量([[1 5][ 1 1][10 1]], 形状=(3, 2), dtype=int32)num_updates, index_depth = tf.convert_to_tensor(indices).shape.as_list()[num_updates, index_depth][2, 2]

Please explain what is index depth of tf.tensor_scatter_nd_update.

tf.tensor_scatter_nd_update(
    tensor, indices, updates, name=None
)

Why indices is 2D for 1D tensor?

indices has at least two axes, the last axis is the depth of the index vectors. For a higher rank input tensor scalar updates can be inserted by using an index_depth that matches tf.rank(tensor):

tensor = [0, 0, 0, 0, 0, 0, 0, 0]    # tf.rank(tensor) == 1
indices = [[1], [3], [4], [7]]       # num_updates == 4, index_depth == 1   # <--- what is depth and why 2D for 1D tensor?
updates = [9, 10, 11, 12]            # num_updates == 4
print(tf.tensor_scatter_nd_update(tensor, indices, updates))

tensor = [[1, 1], [1, 1], [1, 1]]    # tf.rank(tensor) == 2
indices = [[0, 1], [2, 0]]           # num_updates == 2, index_depth == 2
updates = [5, 10]                    # num_updates == 2
print(tf.tensor_scatter_nd_update(tensor, indices, updates))

解决方案

For indices, the index depth is the size or length of the index vectors. For example:

indicesA = [[1], [3], [4], [7]] # index vector with 1 element: index_depth = 1
indicesB = [[0, 1], [2, 0]]     # index vector with 2 element: index_depth = 2


The reason for indices is 2D is to hold two information, one is the length of the updates (num_updates) and the length of the index vector. Two things need to be fulfilled:

  • The index depth of indices must equal the rank of the input tensor
  • The length of updates must equal the length of the indices

So, in the example code

# tf.rank(tensor) == 1
tensor = [0, 0, 0, 0, 0, 0, 0, 0]    

# num_updates == 4, index_depth == 1 | tf.rank(indices).numpy() == 2 
indices = [[1], [3], [4], [7]]    

# num_updates == 4 | tf.rank(output).numpy() == 1  
updates = [9, 10, 11, 12]        

output = tf.tensor_scatter_nd_update(tensor, indices, updates)
tf.Tensor([ 0  9  0 10 11  0  0 12], shape=(8,), dtype=int32)

Also

# tf.rank(tensor) == 2
tensor = [[1, 1], [1, 1], [1, 1]]    

 # num_updates == 2, index_depth == 2 | tf.rank(indices).numpy() == 2
indices = [[0, 1], [2, 0]]          

# num_updates == 2 | tf.rank(output).numpy() == 2
updates = [5, 10]       
             
output = tf.tensor_scatter_nd_update(tensor, indices, updates)
tf.Tensor(
[[ 1  5]
 [ 1  1]
 [10  1]], shape=(3, 2), dtype=int32)

num_updates, index_depth = tf.convert_to_tensor(indices).shape.as_list()
[num_updates, index_depth]
[2, 2]

这篇关于Tensorflow 2 - tensor_scatter_nd_update 中的“索引深度"是什么?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆