更新张量流中的张量切片 [英] Update slice of tensor in tensorflow

查看:40
本文介绍了更新张量流中的张量切片的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我想更新一个 3 维张量的切片.按照 如何在 Tensorflow 中进行切片分配,我会做类似的事情

I want to update a slice of a tensor with 3 dimensions. Following How to do slice assignment in Tensorflow I would do something like

import tensorflow as tf

with tf.Session() as sess:
    init_val = tf.Variable(tf.zeros((2, 3, 3)))
    indices = tf.constant([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1]])
    update = tf.scatter_nd_add(init_val, indices, tf.ones(4))

    init = tf.global_variables_initializer()
    sess.run(init)
    print(sess.run(update))

这可行,但由于我的实际问题更复杂,我想通过定义切片的开头和大小以某种方式自动生成索引集,例如如果您使用 tf.slice(...).你有什么想法?提前致谢!

This works, but since my actual problem is more complex I would like to generate the set of indices somehow automatically by defining the beginning and the size of the slice, such as if you would use tf.slice(...). Do you have any ideas? Thanks in advance!

我使用的是当前最新版本的 TensorFlow 1.12.

I am using TensorFlow 1.12, which is currently the most recent release.

推荐答案

tf.strided_slice 支持传递一个 var 参数来指示切片引用的变量,所以当你传递它时它会返回一个可赋值的对象(我不知道为什么他们不只是根据输入的类型来做,但无论如何).你可以这样做:

tf.strided_slice supports passing a var parameter to indicate the variable the slice refers to, so when you pass it it will return an assignable object (I'm not sure why they didn't just do that depending on the type of the input, but whatever). You can do something like this:

import tensorflow as tf
import numpy as np

var = tf.Variable(np.ones((3, 4), dtype=np.float32))
s = tf.strided_slice(var, [0, 2], [2, 3], var=var, name='var_slice')
s2 = s.assign([[2], [3]])
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    print(sess.run(s2))

输出:

[[1. 1. 2. 1.]
 [1. 1. 3. 1.]
 [1. 1. 1. 1.]]

请注意,在 tf.strided_slice您提供开始和结束索引(不包括结束),与 tf.slice,在其中给出开始和大小.此外,就目前的代码而言,您必须为切片或分配操作提供一个名称值(我觉得这应该是一个错误,因为 API 的那部分几乎只在内部使用).

Note that in tf.strided_slice you give the begin and end indices (end not included), unlike intf.slice, where you give begin and size. Also, as the code currently stands, you have to provide a name value either in for the slice or the assign operation (I feel this should be a bug and happens because that part of the API is used almost exclusively internally).

这篇关于更新张量流中的张量切片的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆