使用K.switch()进行keras(tensorflow后端)条件赋值 [英] keras (tensorflow backend) conditional assignment with K.switch()

查看:443
本文介绍了使用K.switch()进行keras(tensorflow后端)条件赋值的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在尝试实现类似的东西

I'm trying to implement something like

if np.max(subgrid) == np.min(subgrid):
    middle_middle = cur_subgrid + 1
else:
    middle_middle = cur_subgrid

由于只能在运行时确定条件,因此我使用Keras语法如下

Since the condition can only be determined at run-time, I'm using Keras syntax as following

middle_middle = K.switch(K.max(subgrid) == K.min(subgrid), lambda: tf.add(cur_subgrid,1), lambda: cur_subgrid)

但是我遇到了这个错误:

But I'm getting this error:

<ipython-input-112-0504ce070e71> in col_loop(j, gray_map, mask_A)
     56 
     57 
---> 58             middle_middle = K.switch(K.max(subgrid) == K.min(subgrid), lambda: tf.add(cur_subgrid,1), lambda: cur_subgrid)
     59 
     60             print ('ml',middle_left.shape)

/nfs/isicvlnas01/share/anaconda3/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py in switch(condition, then_expression, else_expression)    2561         The selected tensor.    2562     """
-> 2563     if condition.dtype != tf.bool:    2564         condition = tf.cast(condition, 'bool')    2565     if not callable(then_expression):

AttributeError: 'bool' object has no attribute 'dtype'

middle_middlecur_subgrid和子网格都是NxN张量.感谢您的帮助.

middle_middle, cur_subgrid, and subgrid are all NxN tensors. Any help is appreciated.

推荐答案

我认为问题在于,使用K.max(subgrid) == K.min(subgrid)您正在创建比较两个张量对象而不是 python布尔值 > tensorflow布尔张量,其中包含两个输入张量的的比较值.

I think the problem is that with K.max(subgrid) == K.min(subgrid) you're creating a python boolean comparing two tensor objects, not a tensorflow boolean tensor containing the value of the comparison of the values of the two input tensors.

换句话说,您所写的内容将被评估为

In other words, what you have written will be evaluated as

K.switch(True, lambda: tf.add(cur_subgrid,1), lambda: cur_subgrid)

代替

comparison = ... # Some tensor, that at runtime will contain True if min and max are the same, False otherwise. 
K.switch(comparison , lambda: tf.add(cur_subgrid,1), lambda: cur_subgrid)

因此,您需要使用 keras.backend. equal()代替==:

So what you need to do is to use keras.backend.equal() instead of ==:

K.switch(K.equal(K.max(subgrid),K.min(subgrid)), lambda: tf.add(cur_subgrid,1), lambda: cur_subgrid)

这篇关于使用K.switch()进行keras(tensorflow后端)条件赋值的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆