How to visualize a nested `tf.keras.Model` (subclassed API) GAN model with plot_model?
Problem description
Models implemented as subclasses of keras.Model generally cannot be visualized with plot_model. There is a workaround (the build_graph method below is adapted from it), but it only applies to simple models. As soon as a model is enclosed by another model, the nesting is not resolved.
I am looking for a way to resolve nested models implemented as subclasses of keras.Model. As an example, I have created a minimal GAN model:
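# Use the public tf.keras API consistently instead of mixing standalone
# keras with the private tensorflow.python.keras path.
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import plot_model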
class BaseModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super(BaseModel, self).__init__(*args, **kwargs)

    def call(self, inputs, training=None, mask=None):
        # Subclasses override this; return the result so it is not lost.
        return super(BaseModel, self).call(inputs=inputs, training=training, mask=mask)

    def get_config(self):
        return super(BaseModel, self).get_config()

    def build_graph(self, raw_shape):
        """ Plot models that subclass `keras.Model`
        Adapted from https://stackoverflow.com/questions/61427583/how-do-i-plot-a-keras-tensorflow-subclassing-api-model
        :param raw_shape: Shape tuple not containing the batch_size
        :return: A functional `keras.Model` wrapping `self.call`
        """
        x = keras.Input(shape=raw_shape)
        return keras.Model(inputs=[x], outputs=self.call(x))
class GANModel(BaseModel):
    def __init__(self, generator, discriminator):
        super(GANModel, self).__init__()
        self.generator = generator
        self.discriminator = discriminator

    def call(self, input_tensor, training=False, mask=None):
        x = self.generator(input_tensor)
        x = self.discriminator(x)
        return x
class DiscriminatorModel(BaseModel):
    def __init__(self, name="Critic"):
        super(DiscriminatorModel, self).__init__(name=name)
        self.l1 = layers.Conv2D(64, 2, activation=layers.ReLU())
        self.flat = layers.Flatten()
        self.dense = layers.Dense(1)

    def call(self, inputs, training=False, mask=None):
        x = self.l1(inputs, training=training)
        x = self.flat(x)
        x = self.dense(x, training=training)
        return x
class GeneratorModel(BaseModel):
    def __init__(self, name="Generator"):
        super(GeneratorModel, self).__init__(name=name)
        self.dense = layers.Dense(128, activation=layers.ReLU())
        self.reshape = layers.Reshape((7, 7, 128))
        self.out = layers.Conv2D(1, (7, 7), activation='tanh', padding="same")

    def call(self, inputs, training=False, mask=None):
        x = self.dense(inputs, training=training)
        x = self.reshape(x)
        x = self.out(x, training=training)
        return x
g = GeneratorModel()
d = DiscriminatorModel()
plot_model(g.build_graph((7, 7, 1)), to_file="generator_model.png",
           expand_nested=True, show_shapes=True)
gan = GANModel(generator=g, discriminator=d)
plot_model(gan.build_graph((7, 7, 1)), to_file="gan_model.png",
           expand_nested=True, show_shapes=True)
Edit
Using the functional Keras API I get the desired result: the nested models are correctly resolved within the GAN model.
# Public tf.keras imports (avoids the private tensorflow.python.keras path).
from tensorflow.keras import Model, layers, optimizers
from tensorflow.keras.utils import plot_model
def get_generator(input_dim):
    initial = layers.Input(shape=input_dim)
    x = layers.Dense(128, activation=layers.ReLU())(initial)
    x = layers.Reshape((7, 7, 128))(x)
    x = layers.Conv2D(1, (7, 7), activation='tanh', padding="same")(x)
    return Model(inputs=initial, outputs=x, name="Generator")

def get_discriminator(input_dim):
    initial = layers.Input(shape=input_dim)
    x = layers.Conv2D(64, 2, activation=layers.ReLU())(initial)
    x = layers.Flatten()(x)
    x = layers.Dense(1)(x)
    return Model(inputs=initial, outputs=x, name="Discriminator")

def get_gan(input_dim, latent_dim):
    initial = layers.Input(shape=input_dim)
    x = get_generator(input_dim)(initial)
    x = get_discriminator(latent_dim)(x)
    return Model(inputs=initial, outputs=x, name="GAN")

m = get_generator((7, 7, 1))
m.compile(optimizer=optimizers.Adam())
plot_model(m, expand_nested=True, show_shapes=True, to_file="generator_model_functional.png")
gan = get_gan((7, 7, 1), (7, 7, 1))
plot_model(gan, expand_nested=True, show_shapes=True, to_file="gan_model_functional.png")
When you pass the generator and the discriminator to GANModel, each one acts as a single enclosed child layer that itself consists of n layers. So if you plot the generator through a GANModel instance, it is drawn as that single child layer (the same goes for the discriminator), unlike the plot you get when you plot each model separately.
The reason is that when data is passed through the call() method of GANModel, the input flows through all internal layers of generator and discriminator implicitly, by design. Here are two workarounds to get the desired plot.
Option 1
You have probably guessed this one already. In GANModel, we pass the input explicitly to each internal layer of the child models (generator, discriminator).
class GANModel(BaseModel):
    def __init__(self, generator, discriminator):
        super(GANModel, self).__init__()
        self.generator = generator
        self.discriminator = discriminator

    def call(self, input_tensor, training=False, mask=None):
        x = input_tensor
        for gen_lyr in self.generator.layers:
            print(gen_lyr)  # checking
            x = gen_lyr(x)
        for disc_lyr in self.discriminator.layers:
            print(disc_lyr)  # checking
            x = disc_lyr(x)
        return x
If you plot now, the print calls in call() list every internal layer that the input passes through:
# All Internal Layers of self.generator, self.discriminator
<tensorflow.python.keras.layers.core.Dense object at 0x7f2a472a3710>
<tensorflow.python.keras.layers.core.Reshape object at 0x7f2a461e8f50>
<tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7f2a44591f90>
<tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7f2a47317290>
<tensorflow.python.keras.layers.core.Flatten object at 0x7f2a47317ed0>
<tensorflow.python.keras.layers.core.Dense object at 0x7f2a57f42910>
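The two loops in Option 1 can be factored into a small helper. The sketch below is illustrative only: `chain_layers` is a hypothetical name, not part of Keras, and it assumes each sub-model is a plain sequential chain whose `.layers` order matches the order in which its `call()` applies them. That holds for the generator and discriminator above, but not for models with branches or skip connections. Toy callables stand in for real layers so the snippet runs without TensorFlow.

```python
from types import SimpleNamespace


def chain_layers(x, *models):
    """Apply every inner layer of each model to x, in `.layers` order.

    Hypothetical helper: assumes each model exposes `.layers` and that this
    order equals the call order (plain sequential sub-models only).
    """
    for model in models:
        for layer in model.layers:
            x = layer(x)
    return x


# Toy stand-ins for sub-models: any object with a `.layers` list of callables.
toy_generator = SimpleNamespace(layers=[lambda v: v + 1, lambda v: v * 2])
toy_discriminator = SimpleNamespace(layers=[lambda v: v - 3])

result = chain_layers(5, toy_generator, toy_discriminator)
print(result)  # ((5 + 1) * 2) - 3 = 9
```

Inside GANModel.call, the body would then reduce to `return chain_layers(input_tensor, self.generator, self.discriminator)`, under the same sequential-order assumption.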
Option 2
I think this approach is a bit ugly. First, we collect every internal layer and build a Sequential model from them. Then we call .build to create its input layer. That's it.
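import tensorflow as tf  # needed for tf.keras below

gan = GANModel(generator=g, discriminator=d)
all_layer = []
for layer in gan.layers:
    # Each child of gan is itself a model; collect its internal layers.
    all_layer.extend(layer.layers)
gan_plot = tf.keras.models.Sequential(all_layer)
gan_plot.build((None, 7, 7, 1))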
list(all_layer)
[<tensorflow.python.keras.layers.core.Dense at 0x7f2a461ab390>,
<tensorflow.python.keras.layers.core.Reshape at 0x7f2a46156110>,
<tensorflow.python.keras.layers.convolutional.Conv2D at 0x7f2a461fedd0>,
<tensorflow.python.keras.layers.convolutional.Conv2D at 0x7f2a461500d0>,
<tensorflow.python.keras.layers.core.Flatten at 0x7f2a4613ea10>,
<tensorflow.python.keras.layers.core.Dense at 0x7f2a462cae10>]
tf.keras.utils.plot_model(gan_plot, expand_nested=True, show_shapes=True)
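Option 2 unpacks exactly one level of nesting (`layer.layers`). If the sub-models themselves contained sub-models, a recursive walk might be needed. The following is a minimal sketch under that assumption: `flatten_layers` is a hypothetical helper, not a Keras function, and it treats any object with a non-empty `layers` attribute as a container. `ToyLayer` and `ToyModel` are stand-ins so the snippet runs without TensorFlow.

```python
class ToyLayer:
    """Stand-in for a leaf layer (has no `layers` attribute)."""

    def __init__(self, name):
        self.name = name


class ToyModel:
    """Stand-in for a model: a named container exposing `.layers`."""

    def __init__(self, name, layers):
        self.name = name
        self.layers = layers


def flatten_layers(node):
    """Return the leaf layers under `node`, depth-first, in `.layers` order."""
    children = getattr(node, "layers", None)
    if not children:
        # No (or empty) `layers` attribute: treat the node itself as a leaf.
        return [node]
    flat = []
    for child in children:
        flat.extend(flatten_layers(child))
    return flat


toy_gen = ToyModel("Generator", [ToyLayer("dense"), ToyLayer("reshape"), ToyLayer("conv_out")])
toy_disc = ToyModel("Critic", [ToyLayer("conv"), ToyLayer("flatten"), ToyLayer("dense_out")])
toy_gan = ToyModel("GAN", [toy_gen, toy_disc])

names = [layer.name for layer in flatten_layers(toy_gan)]
print(names)
```

With real models, the flattened list could be handed to tf.keras.models.Sequential as in Option 2. As there, this only makes sense when the overall architecture is a purely sequential chain.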