TensorFlow 创建动态形状变量 [英] TensorFlow create dynamic shape variable

查看:25
本文介绍了TensorFlow 创建动态形状变量的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我需要创建一个只有在执行时才知道的形状的 tf.Variable.

I need to create a tf.Variable with shape which is known only at the execution time.

我将代码简化为以下要点.我需要找到大于 4 的占位符数字,并且在生成的张量中需要 scatter_update 将第二项设为 24 常量.>

I simplified my code to the following gist. I need to find in placeholder numbers which is greater than 4 and in the resultant tensor need to scatter_update the second item to 24 constant.

import tensorflow as tf

def get_variable(my_variable):
    greater_than = tf.greater(my_variable, tf.constant(4))
    result = tf.boolean_mask(my_variable, greater_than)
    # result = tf.Variable(tf.zeros(tf.shape(result)), trainable=False, expected_shape=tf.shape(result), validate_shape=False)   # doesn't work either
    result = tf.get_variable("my_var", shape=tf.shape(my_variable), dtype=tf.int32)
    result = tf.scatter_update(result, [1], 24)
    return result

input = tf.placeholder(dtype=tf.int32, shape=[5])
    created_variable = get_variable(input)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    result = sess.run(created_variable, feed_dict={input: [2, 7, 4, 6, 9]})
    print(result)

我发现 很少 问题,但他们没有答案,也没有帮助我.

I found few questions but they have no answers and didn't help me.

推荐答案

我遇到了同样的问题,偶然发现了相同的未回答的问题,并设法拼凑了一个解决方案,用于在创建图形时创建具有动态形状的变量时间.请注意,必须在 tf.Session.run(...) 之前或第一次执行时定义形状.

I had the same problem, stumbled upon the same unanswered questions and managed to piece together a solution for creating a variable with a dynamic shape at graph creation time. Note that the shape has to be defined before, or with the first execution of tf.Session.run(...).

import tensorflow as tf

def get_variable(my_variable):
    greater_than = tf.greater(my_variable, tf.constant(4))
    result = tf.boolean_mask(my_variable, greater_than)
    zerofill = tf.fill(tf.shape(my_variable), tf.constant(0, dtype=tf.int32))
    # Initialize
    result = tf.get_variable(
        "my_var", shape=None, validate_shape=False, dtype=tf.int32, initializer=zerofill
    )
    result = tf.scatter_update(result, [1], 24)
    return result

input = tf.placeholder(dtype=tf.int32, shape=[5])
created_variable = get_variable(input)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    result = sess.run(created_variable, feed_dict={input: [2, 7, 4, 6, 9]})
    print(result)

诀窍是用 shape=Nonevalidate_shape=False 创建一个 tf.Variable 并交出一个 tf.Variable .未知形状的张量作为初始化器.

The trick is to create a tf.Variable with shape=None, validate_shape=False and hand over a tf.Tensor with unknown shape as initializer.

这篇关于TensorFlow 创建动态形状变量的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆