了解Tensorflow中的while循环 [英] Understanding the while loop in Tensorflow

查看:227
本文介绍了了解Tensorflow中的while循环的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在使用用于Tensorflow的Python API .我试图在不使用Python循环的情况下实现下面给出的 Rosenbrock函数:

I am using the Python API for Tensorflow. I am trying to implement the Rosenbrock function given below without the use of a Python loop:

我当前的实现如下:

def rosenbrock(data_tensor):
    columns = tf.unstack(data_tensor)

    summation = 0
    for i in range(1, len(columns) - 1):
        first_term = tf.square(tf.subtract(columns[i + 1], tf.square(columns[i])))
        second_term = tf.square(tf.subtract(columns[i], 1.0))
        summation += tf.add(tf.multiply(100.0, first_term), second_term)

    return summation

我尝试在 tf.while_loop() 中实现求和;但是,当使用索引整数来保持与数据的分离时,我发现API有点不直观. 文档中给出的示例使用数据作为索引(反之亦然) ):

I have tried implementing the summation in a tf.while_loop(); however, I found the API somewhat unintuitive when it comes to using an index integer that is meant to remain separate from the data. The example given in the documentation uses the data as the index (or vice-versa):

i = tf.constant(0)
c = lambda i: tf.less(i, 10)
b = lambda i: tf.add(i, 1)
r = tf.while_loop(c, b, [i])

推荐答案

可以使用tf.while_loop()和标准文档.

This can be achieved using the tf.while_loop() and standard tuples as per the second example in the documentation.

def rosenbrock(data_tensor):
    columns = tf.unstack(data_tensor)

    # Track both the loop index and summation in a tuple in the form (index, summation)
    index_summation = (tf.constant(1), tf.constant(0.0))

    # The loop condition, note the loop condition is 'i < n-1'
    def condition(index, summation):
        return tf.less(index, tf.subtract(tf.shape(columns)[0], 1))

    # The loop body, this will return a result tuple in the same form (index, summation)
    def body(index, summation):
        x_i = tf.gather(columns, index)
        x_ip1 = tf.gather(columns, tf.add(index, 1))

        first_term = tf.square(tf.subtract(x_ip1, tf.square(x_i)))
        second_term = tf.square(tf.subtract(x_i, 1.0))
        summand = tf.add(tf.multiply(100.0, first_term), second_term)

        return tf.add(index, 1), tf.add(summation, summand)

    # We do not care about the index value here, return only the summation
    return tf.while_loop(condition, body, index_summation)[1]

重要的是要注意,索引增量应该出现在循环主体中,类似于标准的while循环.在给定的解决方案中,它是body()函数返回的元组中的第一项.

It is important to note that the index increment should occur in the body of the loop similar to a standard while loop. In the solution given, it is the first item in the tuple returned by the body() function.

此外,循环条件函数必须为求和分配一个参数,尽管在此特定示例中未使用该参数.

Additionally, the loop condition function must allot a parameter for the summation although it is not used in this particular example.

这篇关于了解Tensorflow中的while循环的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆