Can I use a dictionary in a customized Keras model?


Problem description


I recently read a paper about UNet++, and I want to implement this architecture with TensorFlow 2.0 and a customized Keras model. Because the structure is so complicated, I decided to manage the Keras layers with a dictionary. Everything went well in training, but an error occurred while saving the model. Here is a minimal example that reproduces the error:

import numpy as np
import tensorflow as tf
from sklearn.utils import class_weight

class DicModel(tf.keras.Model):
    def __init__(self):
        super(DicModel, self).__init__(name='SequenceEECNN')
        # Sub-layers are kept in a dict with integer keys. Keras wraps this
        # dict automatically on assignment; the integer keys are what make
        # save_weights(..., save_format='tf') fail below.
        self.c = {}
        self.c[0] = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu', padding='same'),
            tf.keras.layers.BatchNormalization()]
        )
        self.c[1] = tf.keras.layers.Conv2D(3, 3, activation='softmax', padding='same')

    def call(self, images):
        x = self.c[0](images)
        x = self.c[1](x)
        return x

# load_data() is the asker's own data loader (not shown in the question).
X_train, y_train = load_data()
X_test, y_test = load_data()

class_weight.compute_class_weight('balanced', classes=np.ravel(np.unique(y_train)), y=np.ravel(y_train))

model = DicModel()
model_name = 'test'
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='logs/' + model_name + '/')
early_stop_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=100, mode='min')

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss=tf.keras.losses.sparse_categorical_crossentropy,
              metrics=['accuracy'])

results = model.fit(X_train, y_train, batch_size=4, epochs=5, validation_data=(X_test, y_test),
                    callbacks=[tensorboard_callback, early_stop_callback],
                    class_weight=[0.2, 2.0, 100.0])

# Saving in TensorFlow checkpoint format raises the ValueError shown below.
model.save_weights('model/' + model_name, save_format='tf')

The error message is:

Traceback (most recent call last):
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/learn_tf2/test_model.py", line 61, in <module>
    model.save_weights('model/'+model_name,save_format='tf')
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py", line 1328, in save_weights
    self._trackable_saver.save(filepath, session=session)
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/util.py", line 1106, in save
    file_prefix=file_prefix_tensor, object_graph_tensor=object_graph_tensor)
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/util.py", line 1046, in _save_cached_when_graph_building
    object_graph_tensor=object_graph_tensor)
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/util.py", line 1014, in _gather_saveables
    feed_additions) = self._graph_view.serialize_object_graph()
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/graph_view.py", line 379, in serialize_object_graph
    trackable_objects, path_to_root = self._breadth_first_traversal()
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/graph_view.py", line 199, in _breadth_first_traversal
    for name, dependency in self.list_dependencies(current_trackable):
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/graph_view.py", line 159, in list_dependencies
    return obj._checkpoint_dependencies
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/data_structures.py", line 690, in __getattribute__
    return object.__getattribute__(self, name)
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/data_structures.py", line 732, in _checkpoint_dependencies
    "ignored." % (self,))
ValueError: Unable to save the object {0: <tensorflow.python.keras.engine.sequential.Sequential object at 0x7fb5c6c36588>, 1: <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fb5c6c36630>} (a dictionary wrapper constructed automatically on attribute assignment). The wrapped dictionary contains a non-string key which maps to a trackable object or mutable data structure.

If you don't need this dictionary checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency object; it will be automatically un-wrapped and subsequently ignored.


tf.contrib.checkpoint.NoDependency seems to have been removed in TensorFlow 2.0 (https://medium.com/tensorflow/whats-coming-in-tensorflow-2-0-d3663832e9b8). How can I fix this issue? Or should I just give up on using dictionaries in customized Keras models? Thank you for your time and help!

Recommended answer


The exception message was incorrect in TensorFlow 2.0 (it points to tf.contrib.checkpoint.NoDependency, although tf.contrib no longer exists in 2.x); this has been fixed in 2.2.


You can avoid the problem by wrapping the c attribute in NoDependency, like this:

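# Private TensorFlow module; this import path may differ in other versions.
from tensorflow.python.training.tracking.data_structures import NoDependency
self.c = NoDependency({})  # inside DicModel.__init__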
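The error message itself says two things worth noting: the dictionary wrapper only objects to a non-string key, and a dictionary wrapped in NoDependency is ignored for checkpointing, so layers held only in that dictionary may not end up in the saved weights. An alternative worth trying is to keep the dictionary tracked and use string keys instead. The sketch below assumes a recent TensorFlow 2.x with tf.keras; the class name StrKeyModel, the dummy input shape, and the temporary file name are illustrative assumptions, not part of the original question. It saves, restores into a fresh instance, and checks that both produce the same output.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class StrKeyModel(tf.keras.Model):
    """Hypothetical variant of the question's DicModel using string keys."""

    def __init__(self):
        super().__init__(name='SequenceEECNN')
        # String keys avoid the "non-string key" ValueError raised by TF 2.x
        # checkpointing; the dict and its layers stay tracked by Keras.
        self.c = {}
        self.c['0'] = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu', padding='same'),
            tf.keras.layers.BatchNormalization(),
        ])
        self.c['1'] = tf.keras.layers.Conv2D(3, 3, activation='softmax', padding='same')

    def call(self, images):
        x = self.c['0'](images)
        return self.c['1'](x)


# Build the model by calling it once on a dummy batch (shape is an assumption).
images = np.random.rand(2, 16, 16, 1).astype('float32')
model = StrKeyModel()
out = model(images)

# Save the weights, then load them into a freshly built instance.
path = os.path.join(tempfile.mkdtemp(), 'str_key_model.weights.h5')
model.save_weights(path)

restored = StrKeyModel()
restored(images)  # build the variables before loading
restored.load_weights(path)
```

Since the ValueError is raised specifically for non-string keys, switching DicModel to keys such as '0' and '1' should also let save_weights(..., save_format='tf') go through on TF 2.x. The sketch writes a .weights.h5 file only so that it also runs under Keras 3, whose save_weights no longer accepts save_format.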


For more details, see this issue.

