如何绘制Keras / Tensorflow子类化API模型? [英] How do I plot a Keras/Tensorflow subclassing API model?

查看:215
本文介绍了如何绘制Keras / Tensorflow子类化API模型?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我使用Keras子类化API制作了可以正常运行的模型。 model.summary()也可以正常工作。尝试使用 tf.keras.utils.plot_model()可视化模型的体系结构时,它将仅输出以下图像:

I made a model that runs correctly using the Keras Subclassing API. The model.summary() also works correctly. When trying to use tf.keras.utils.plot_model() to visualize my model's architecture, it will just output this image:

这几乎就像是Keras开发团队的笑话。这是完整的体系结构:

This almost feels like a joke from the Keras development team. This is the full architecture:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from sklearn.datasets import load_diabetes
import tensorflow as tf
tf.keras.backend.set_floatx('float64')
from tensorflow.keras.layers import Dense, GaussianDropout, GRU, Concatenate, Reshape
from tensorflow.keras.models import Model

X, y = load_diabetes(return_X_y=True)

data = tf.data.Dataset.from_tensor_slices((X, y)).\
    shuffle(len(X)).\
    map(lambda x, y: (tf.divide(x, tf.reduce_max(x)), y))

training = data.take(400).batch(8)
testing = data.skip(400).map(lambda x, y: (tf.expand_dims(x, 0), y))

class NeuralNetwork(Model):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.dense1 = Dense(16, input_shape=(10,), activation='relu', name='Dense1')
        self.dense2 = Dense(32, activation='relu', name='Dense2')
        self.resha1 = Reshape((1, 32))
        self.gru1 = GRU(16, activation='tanh', recurrent_dropout=1e-1)
        self.dense3 = Dense(64, activation='relu', name='Dense3')
        self.gauss1 = GaussianDropout(5e-1)
        self.conca1 = Concatenate()
        self.dense4 = Dense(128, activation='relu', name='Dense4')
        self.dense5 = Dense(1, name='Dense5')

    def call(self, x, *args, **kwargs):
        x = self.dense1(x)
        x = self.dense2(x)
        a = self.resha1(x)
        a = self.gru1(a)
        b = self.dense3(x)
        b = self.gauss1(b)
        x = self.conca1([a, b])
        x = self.dense4(x)
        x = self.dense5(x)
        return x


skynet = NeuralNetwork()
skynet.build(input_shape=(None, 10))
skynet.summary()

model = tf.keras.utils.plot_model(model=skynet,
         show_shapes=True, to_file='/home/nicolas/Desktop/model.png')


推荐答案

之所以无法完成,是因为基本上模型子类(在TensorFlow中实现)在功能上受到限制与使用Functional / Sequential API(在TF术语中称为Graph网络)创建的模型相比。如果检查 plot_model 源代码,则会看到以下检查 model_to_dot 函数中(由<$ c $调用c> plot_model ):

It could not be done because basically model sub-classing, as it is implemented in TensorFlow, is limited in features and capabilities compared to the models created using Functional/Sequential API (which are called Graph networks in TF terminology). If you check the plot_model source code, you would see the following check in model_to_dot function (which is called by plot_model):

if not model._is_graph_network:
  node = pydot.Node(str(id(model)), label=model.name)
  dot.add_node(node)
  return dot

正如我提到的那样,子类模型不是图网络,因此将为这些模型绘制仅包含模型名称的节点(即,您观察到的相同)。

As I mentioned, the sub-classed models are not graph networks and therefore only a node containing the model name would be plotted for these models (i.e. the same thing you observed).

已经在 Github问题中进行了讨论,并且TensorFlow的开发者之一已得到确认通过提供以下参数,可以实现这种行为:

This has been already discussed in a Github issue and one of the developers of TensorFlow confirmed this behavior by giving the following argument:


@ omalleyt12评论:

@omalleyt12 commented:

是的,总的来说,我们不能对子类模型的结构做任何假设。如果您的模型可以看作是分层的图层,并且希望以这种方式可视化,则建议您查看Functional API

Yes in general we can't assume anything about the structure of a subclassed Model. If your Model can be though of as blocks of Layers and you wish to visualize it like that, we recommend you view the Functional API

这篇关于如何绘制Keras / Tensorflow子类化API模型?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆