How to visualize a nested `tf.keras.Model` (SubClassed API) GAN model with plot_model?


Problem description


Models implemented as subclasses of keras.Model generally cannot be visualized with plot_model. There is a workaround, as described here, but it only applies to simple models: as soon as a model is enclosed by another model, the nesting is not resolved.

I am looking for a way to resolve nested models implemented as subclasses of keras.Model. As an example, I have created a minimal GAN model:

import keras
from keras import layers
from tensorflow.python.keras.utils.vis_utils import plot_model


class BaseModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super(BaseModel, self).__init__(*args, **kwargs)

    def call(self, inputs, training=None, mask=None):
        super(BaseModel, self).call(inputs=inputs, training=training, mask=mask)

    def get_config(self):
        super(BaseModel, self).get_config()

    def build_graph(self, raw_shape):
        """ Plot models that subclass `keras.Model`

        Adapted from https://stackoverflow.com/questions/61427583/how-do-i-plot-a-keras-tensorflow-subclassing-api-model

        :param raw_shape: Shape tuple not containing the batch_size
        :return:
        """
        x = keras.Input(shape=raw_shape)
        return keras.Model(inputs=[x], outputs=self.call(x))


class GANModel(BaseModel):
    def __init__(self, generator, discriminator):
        super(GANModel, self).__init__()
        self.generator = generator
        self.discriminator = discriminator

    def call(self, input_tensor, training=False, mask=None):
        x = self.generator(input_tensor)
        x = self.discriminator(x)
        return x


class DiscriminatorModel(BaseModel):
    def __init__(self, name="Critic"):
        super(DiscriminatorModel, self).__init__(name=name)
        self.l1 = layers.Conv2D(64, 2, activation=layers.ReLU())
        self.flat = layers.Flatten()
        self.dense = layers.Dense(1)

    def call(self, inputs, training=False, mask=None):
        x = self.l1(inputs, training=training)
        x = self.flat(x)
        x = self.dense(x, training=training)
        return x


class GeneratorModel(BaseModel):
    def __init__(self, name="Generator"):
        super(GeneratorModel, self).__init__(name=name)
        self.dense = layers.Dense(128, activation=layers.ReLU())
        self.reshape = layers.Reshape((7, 7, 128))
        self.out = layers.Conv2D(1, (7, 7), activation='tanh', padding="same")

    def call(self, inputs, training=False, mask=None):
        x = self.dense(inputs, training=training)
        x = self.reshape(x)
        x = self.out(x, training=training)
        return x


g = GeneratorModel()
d = DiscriminatorModel()

plot_model(g.build_graph((7, 7, 1)), to_file="generator_model.png",
           expand_nested=True, show_shapes=True)

gan = GANModel(generator=g, discriminator=d)
plot_model(gan.build_graph((7, 7, 1)), to_file="gan_model.png", 
           expand_nested=True, show_shapes=True)


Edit

Using the functional Keras API, I get the desired result (see here). The nested models are correctly resolved within the GAN model.

from keras import Model, layers, optimizers
from tensorflow.python.keras.utils.vis_utils import plot_model


def get_generator(input_dim):
    initial = layers.Input(shape=input_dim)

    x = layers.Dense(128, activation=layers.ReLU())(initial)
    x = layers.Reshape((7, 7, 128))(x)
    x = layers.Conv2D(1, (7, 7), activation='tanh', padding="same")(x)

    return Model(inputs=initial, outputs=x, name="Generator")


def get_discriminator(input_dim):
    initial = layers.Input(shape=input_dim)

    x = layers.Conv2D(64, 2, activation=layers.ReLU())(initial)
    x = layers.Flatten()(x)
    x = layers.Dense(1)(x)

    return Model(inputs=initial, outputs=x, name="Discriminator")

def get_gan(input_dim, latent_dim):
    initial = layers.Input(shape=input_dim)

    x = get_generator(input_dim)(initial)
    x = get_discriminator(latent_dim)(x)

    return Model(inputs=initial, outputs=x, name="GAN")



m = get_generator((7, 7, 1))
m.compile(optimizer=optimizers.Adam())

plot_model(m, expand_nested=True, show_shapes=True, to_file="generator_model_functional.png")

gan = get_gan((7, 7, 1), (7, 7, 1))
plot_model(gan, expand_nested=True, show_shapes=True, to_file="gan_model_functional.png")

Solution

Whenever you pass the generator and discriminator to GANModel, each of them acts as a single wrapped child layer that contains n inner layers. So if you plot the generator through the GANModel instance, it shows up as one collapsed block (the same goes for the discriminator), unlike the plots you get when you plot each model separately.

The reason is that when data flows through the call() method of GANModel, the input passes implicitly through all internal layers of the generator and discriminator according to their design. Here are two workarounds to get the desired plot.


Option 1

You can probably guess this approach already: in GANModel, we pass the input explicitly through each internal layer of those child models (generator, discriminator).

class GANModel(BaseModel):
    def __init__(self, generator, discriminator):
        super(GANModel, self).__init__()
        self.generator = generator
        self.discriminator = discriminator

    def call(self, input_tensor, training=False, mask=None):
        x = input_tensor

        for gen_lyr in self.generator.layers:
            print(gen_lyr) # checking 
            x = gen_lyr(x)

        for disc_lyr in self.discriminator.layers:
            print(disc_lyr) # checking 
            x = disc_lyr(x)

        return x

If you plot the model now, the nested layers are resolved; the print output confirms every internal layer the input passes through:

# All Internal Layers of self.generator, self.discriminator
<tensorflow.python.keras.layers.core.Dense object at 0x7f2a472a3710>
<tensorflow.python.keras.layers.core.Reshape object at 0x7f2a461e8f50>
<tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7f2a44591f90>
<tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7f2a47317290>
<tensorflow.python.keras.layers.core.Flatten object at 0x7f2a47317ed0>
<tensorflow.python.keras.layers.core.Dense object at 0x7f2a57f42910>
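
For completeness, a minimal sketch of how the plot itself could then be produced, reusing the build_graph helper and the (7, 7, 1) input shape from the question (the output file name here is just an example):

gan = GANModel(generator=g, discriminator=d)

# build_graph wraps the explicit call() trace in a functional keras.Model,
# so plot_model can now see every internal layer of generator and discriminator
plot_model(gan.build_graph((7, 7, 1)), to_file="gan_model_expanded.png",
           expand_nested=True, show_shapes=True)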


Option 2

I think this is a bit of an ugly approach. First, we take every internal layer and build a Sequential model from them. Then we use .build to create its input layer. BOOM.

gan = GANModel(generator=g, discriminator=d)

all_layer = []
for layer in gan.layers: 
    all_layer.extend(layer.layers)

gan_plot = tf.keras.models.Sequential(all_layer)
gan_plot.build((None,7,7,1))
list(all_layer)

[<tensorflow.python.keras.layers.core.Dense at 0x7f2a461ab390>,
 <tensorflow.python.keras.layers.core.Reshape at 0x7f2a46156110>,
 <tensorflow.python.keras.layers.convolutional.Conv2D at 0x7f2a461fedd0>,
 <tensorflow.python.keras.layers.convolutional.Conv2D at 0x7f2a461500d0>,
 <tensorflow.python.keras.layers.core.Flatten at 0x7f2a4613ea10>,
 <tensorflow.python.keras.layers.core.Dense at 0x7f2a462cae10>]

tf.keras.utils.plot_model(gan_plot, expand_nested=True, show_shapes=True)
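
Note that flattening everything into a Sequential model like this only works because the GAN here is a strictly linear chain (generator followed by discriminator) whose layer shapes line up end to end; for architectures with branches or multiple inputs, Option 1 (or the functional API approach from the question's edit) is the safer route.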
