Tensorflow tf.cond 评估两者 [英] Tensorflow tf.cond evaluating both pedicate

查看:37
本文介绍了Tensorflow tf.cond 评估两者的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

import tensorflow as tf
import numpy as np

isTrain = tf.placeholder(tf.bool)
user_input = tf.placeholder(tf.float32)

# ema = tf.train.ExponentialMovingAverage(decay=.5)

with tf.device('/cpu:0'):
    beta = tf.Variable(tf.ones([1]))

    batch_mean = beta.assign(user_input)
    ema = tf.train.ExponentialMovingAverage(decay=0.5)
    ema_apply_op = ema.apply([batch_mean])
    ema_mean = ema.average(batch_mean)

    def mean_var_with_update():
        with tf.control_dependencies([ema_apply_op]):
            return tf.identity(batch_mean)

    mean = tf.cond(isTrain,
        mean_var_with_update,
        lambda: (ema_mean))

# ======= End Here ==========
saver = tf.train.Saver()
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

u_input = [[2], [3], [4] ]
for u in u_input:
    aa = sess.run([mean], feed_dict={user_input:u, isTrain: True })
    print("Train", aa)

for u in u_input:
    aa = sess.run([ema_mean], feed_dict={user_input:u, isTrain: False })
    print("Test correct", aa)

for u in u_input:
    aa = sess.run([mean], feed_dict={user_input:u, isTrain: False })
    print("Test", aa)

此代码片段应计算整个训练阶段的 user_input 平均值和测试阶段的输出平均值.

This code snippet should calculate the mean of user_input across training stage and output mean during testing stage.

这是输出结果:

('Train', [array([ 2.], dtype=float32)])
('Train', [array([ 3.], dtype=float32)])
('Train', [array([ 4.], dtype=float32)])
('Test correct', [array([ 3.], dtype=float32)])
('Test correct', [array([ 3.], dtype=float32)])
('Test correct', [array([ 3.], dtype=float32)])
('Test', [array([ 2.5], dtype=float32)])
('Test', [array([ 2.75], dtype=float32)])
('Test', [array([ 3.375], dtype=float32)])

但是,即使 isTrain = False,在调用 sess.run([mean]) 时总是会评估 ema_mean.

However, ema_mean always get evaluated when calling sess.run([mean]) even if isTrain = False.

代码有错误吗?张量流版本是 0.7.1

Is there any mistake in the code ? tensorflow version is 0.7.1

推荐答案

我认为这与 在这里回答.条件语句中的 tf.control_dependencies 会将依赖添加到 tf.cond 本身.

I think that is the same as answered here. The tf.control_dependencies inside the conditionals will add the dependencies to the tf.cond itself.

所以尝试在 mean_var_with_update 函数内创建 ema_apply_op.

So try to create the ema_apply_op inside the mean_var_with_update function.

这篇关于Tensorflow tf.cond 评估两者的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆