使用K.switch()进行keras(tensorflow后端)条件赋值 [英] keras (tensorflow backend) conditional assignment with K.switch()
问题描述
我正在尝试实现类似的东西
I'm trying to implement something like
if np.max(subgrid) == np.min(subgrid):
middle_middle = cur_subgrid + 1
else:
middle_middle = cur_subgrid
由于只能在运行时确定条件,因此我使用Keras语法如下
Since the condition can only be determined at run-time, I'm using Keras syntax as following
middle_middle = K.switch(K.max(subgrid) == K.min(subgrid), lambda: tf.add(cur_subgrid,1), lambda: cur_subgrid)
但是我遇到了这个错误:
But I'm getting this error:
<ipython-input-112-0504ce070e71> in col_loop(j, gray_map, mask_A)
56
57
---> 58 middle_middle = K.switch(K.max(subgrid) == K.min(subgrid), lambda: tf.add(cur_subgrid,1), lambda: cur_subgrid)
59
60 print ('ml',middle_left.shape)
/nfs/isicvlnas01/share/anaconda3/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py in switch(condition, then_expression, else_expression) 2561 The selected tensor. 2562 """
-> 2563 if condition.dtype != tf.bool: 2564 condition = tf.cast(condition, 'bool') 2565 if not callable(then_expression):
AttributeError: 'bool' object has no attribute 'dtype'
middle_middle
,cur_subgrid
和子网格都是NxN
张量.感谢您的帮助.
middle_middle
, cur_subgrid
, and subgrid are all NxN
tensors. Any help is appreciated.
推荐答案
我认为问题在于,使用K.max(subgrid) == K.min(subgrid)
您正在创建比较两个张量对象而不是的 python布尔值 > tensorflow布尔张量,其中包含两个输入张量的值的比较值.
I think the problem is that with K.max(subgrid) == K.min(subgrid)
you're creating a python boolean comparing two tensor objects, not a tensorflow boolean tensor containing the value of the comparison of the values of the two input tensors.
换句话说,您所写的内容将被评估为
In other words, what you have written will be evaluated as
K.switch(True, lambda: tf.add(cur_subgrid,1), lambda: cur_subgrid)
代替
comparison = ... # Some tensor, that at runtime will contain True if min and max are the same, False otherwise.
K.switch(comparison , lambda: tf.add(cur_subgrid,1), lambda: cur_subgrid)
因此,您需要使用 keras.backend. equal()代替==
:
So what you need to do is to use keras.backend.equal() instead of ==
:
K.switch(K.equal(K.max(subgrid),K.min(subgrid)), lambda: tf.add(cur_subgrid,1), lambda: cur_subgrid)
这篇关于使用K.switch()进行keras(tensorflow后端)条件赋值的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!