Training different branches of a model network with tf.switch_case
Problem description
I want to create a neural network in which a different branch of the network is trained depending on t_input. t_input can be either 0 or 1, and only the branch selected by that value should be trained:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda, Dense
x = np.random.uniform(size=(10, 10))
t = np.random.binomial(100, 0.5)
t_input = Input(batch_shape=(1,), dtype='int32', name="t_input")
x_input = Input(shape=(x.shape[0]), name='x_input')
x = Dense(32)(x_input)
x1 = Dense(16)(x)
x1 = Dense(8)(x1)
x1 = Dense(1)(x1)
x2 = Dense(16)(x)
x2 = Dense(8)(x2)
x2 = Dense(1)(x2)
x1 = lambda: x1
x2 = lambda: x2
r = Lambda(lambda x: tf.switch_case(x, branch_fns={0: x1, 1: x2}))(t_input)
# r = tf.case([(tf.equal(t_input, 1), x1), (tf.equal(t_input, 0), x2)], default=x2, exclusive=True)
model = tf.keras.models.Model(inputs=t_input, outputs=r)
print(model.predict([1]))
However, I cannot get this to work, because tf.switch_case does not seem flexible enough to handle KerasTensors:
Traceback (most recent call last):
File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\IPython\core\interactiveshell.py", line 3437, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-59-92db0d55c181>", line 23, in <module>
r = Lambda(lambda x: tf.switch_case(x, branch_fns={0: x1, 1: x2}))(t_input)
File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 952, in __call__
input_list)
File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1091, in _functional_construction_call
inputs, input_masks, args, kwargs)
File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 822, in _keras_tensor_symbolic_call
return self._infer_output_signature(inputs, args, kwargs, input_masks)
File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 863, in _infer_output_signature
outputs = call_fn(inputs, *args, **kwargs)
File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\keras\layers\core.py", line 917, in call
result = self.function(inputs, **kwargs)
File "<ipython-input-59-92db0d55c181>", line 23, in <lambda>
r = Lambda(lambda x: tf.switch_case(x, branch_fns={0: x1, 1: x2}))(t_input)
File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 3616, in switch_case
return _indexed_case_helper(branch_fns, default, branch_index, name)
File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 3326, in _indexed_case_helper
lower_using_switch_merge=lower_using_switch_merge)
File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\ops\cond_v2.py", line 1040, in indexed_case
op_return_value=branch_index))
File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\framework\func_graph.py", line 995, in func_graph_from_py_func
expand_composites=True)
File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\util\nest.py", line 659, in map_structure
structure[0], [func(*x) for x in entries],
File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\util\nest.py", line 659, in <listcomp>
structure[0], [func(*x) for x in entries],
File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\framework\func_graph.py", line 952, in convert
(str(python_func), type(x)))
TypeError: To be compatible with tf.eager.defun, Python functions must return zero or more Tensors; in compilation of <function <lambda> at 0x000001ED0876EAF8>, found return value of type <class 'function'>, which is not a Tensor.
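The final TypeError says the branch function returned a `function` rather than a Tensor. One likely reason, offered here as an editorial reading and not something stated in the original thread, is Python's late binding: in `x1 = lambda: x1`, the `x1` inside the lambda body is looked up when the lambda is called, and by then that name refers to the lambda itself. A minimal pure-Python demonstration, where a string stands in for a Dense output tensor (an assumption for illustration only):

```python
# Stand-in for a Dense output tensor; any object works for this demonstration.
x1 = "dense_output_tensor"

# The body looks up the *name* x1 at call time, and by then x1 is the lambda,
# so the "branch function" returns a function instead of the tensor.
x1 = lambda: x1
print(x1() is x1)  # True

# Binding the current value as a default argument captures it at definition time.
y1 = "dense_output_tensor"
y1 = lambda y1=y1: y1
print(y1())  # dense_output_tensor
```

Capturing the value would address this particular message; the answer below takes a different route and replaces tf.switch_case altogether.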
Recommended answer
I got your code working by replacing tf.switch_case with the Keras backend switch (K.backend.switch) and by passing both inputs to the model (your code only passed t_input). Note that I had to tile your t_test input, because the model expects both inputs to have the same batch dimension. I am also not sure that you want np.random.binomial: binomial(100, 0.5) counts successes in 100 trials, so it returns values around 50 and will almost never return 0 (the probability is 0.5**100). Since any nonzero value selects the first branch, every sample would effectively go to x1. You should probably use np.random.randint instead and limit it to the values 0 or 1.
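A small NumPy check of the sampling point above. The names t_binomial and t_bernoulli are illustrative, and drawing t_test with size=(10, 1) is an assumption: it matches a batch of 10 and Input(shape=(1,)) directly, so the np.tile step would no longer be needed.

```python
import numpy as np

# Binomial(n=100, p=0.5): the chance of drawing 0 is 0.5**100.
p_zero = 0.5 ** 100
print(p_zero)  # roughly 7.9e-31

# Values cluster around 50; any nonzero t selects the first branch.
t_binomial = np.random.binomial(100, 0.5, size=(10, 1))

# randint's upper bound is exclusive, so randint(0, 2) yields only 0 or 1.
t_test = np.random.randint(0, 2, size=(10, 1))
print(t_test.ravel())

# A Bernoulli draw, binomial with n=1, also yields only 0 or 1.
t_bernoulli = np.random.binomial(1, 0.5, size=(10, 1))
```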
import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda, Dense
import tensorflow.keras as K
import numpy as np
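x_test = np.random.uniform(size=(10, 10))
# One treatment value, tiled to the batch size of x_test in predict() below
t_test = np.array([np.random.binomial(100, 0.5)])
t_input = Input(shape=(1,), dtype=tf.int32, name="t_input")
x_input = Input(shape=(x_test.shape[1],), name='x_input')
# Shared layer feeding both branches
x = Dense(32)(x_input)
# Branch 1: selected where t_input is nonzero
x1 = Dense(16)(x)
x1 = Dense(8)(x1)
x1 = Dense(1)(x1)
# Branch 2: selected where t_input is 0
x2 = Dense(16)(x)
x2 = Dense(8)(x2)
x2 = Dense(1)(x2)
# Per-sample selection between the two branch outputs
r = K.backend.switch(t_input, x1, x2)
# The model needs both inputs, not only t_input
model = tf.keras.models.Model(inputs=[t_input, x_input], outputs=r)
# np.tile repeats t_test 10 times so both inputs have batch size 10
print(model.predict([np.tile(t_test, 10), x_test]))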
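In this model, t_input and the branch outputs both have shape (batch, 1). When the condition is not a scalar, tf.keras.backend.switch casts it to bool and selects element-wise with tf.where instead of using tf.cond. Both branches are computed for every sample, but each sample's output, and the gradient flowing back from its loss, goes only to the branch its t selects; that routing is what trains each head on its own subset of samples. The NumPy sketch below mimics only the selection with np.where; the branch outputs are made-up numbers, not model predictions, and the gradient masks are an illustration of the routing rather than TensorFlow code.

```python
import numpy as np

# Hypothetical per-sample outputs of the two heads, shape (batch, 1).
x1_out = np.array([[1.0], [2.0], [3.0], [4.0]])
x2_out = np.array([[-1.0], [-2.0], [-3.0], [-4.0]])

# Per-sample treatment indicator, shape (batch, 1), like t_input.
t = np.array([[1], [0], [0], [1]], dtype=np.int32)

# Cast the integer condition to bool, then select element-wise.
cond = t.astype(bool)
r = np.where(cond, x1_out, x2_out)
print(r.ravel())  # [ 1. -2. -3.  4.]

# Which samples send gradient to each head under an element-wise select.
grad_to_x1 = np.where(cond, 1.0, 0.0)
grad_to_x2 = 1.0 - grad_to_x1
```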