Tensorflow 2:加载后不再能够跟踪子类模型的属性 [英] Tensorflow 2: No Longer Able to Track Attributes of a Subclassed Model After Loaded

查看:55
本文介绍了Tensorflow 2:加载后不再能够跟踪子类模型的属性的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

这是我在 Tensorflow 2.5 中实现的子类模型:

Here is my implementation of a Subclassed Model in Tensorflow 2.5:

from tensorflow.keras import Model, Input
from tensorflow.keras.applications import DenseNet201
from tensorflow.keras.applications.densenet import preprocess_input
from tensorflow.keras.layers import Conv2D, Flatten, Dense
from tensorflow.random import uniform
from tensorflow.keras.models import load_model 

class Detector(Model):
    
    def __init__(self, num_classes=3, name="DenseNet201"):
        super(Detector, self).__init__(name=name)
        self.feature_extractor = DenseNet201(
            include_top=False,
            weights="imagenet",
        )
        self.feature_extractor.trainable = False
        self.flatten_layer = Flatten()
        self.prediction_layer = Dense(num_classes, activation=None)

    def call(self, inputs):
        x = preprocess_input(inputs)
        self.extracted_feature = self.feature_extractor(x, training=False)
        x = self.flatten_layer(self.extracted_feature)
        x = self.prediction_layer(x)
        return x

在测试我的代码时,我发现了一些让我感到困惑的东西.

While testing my code, I found something that really confused me.

detector = Detector()
print(detector.extracted_feature)

这给了我一个错误:AttributeError: 'Detector' object has no attribute 'extracted_feature',这是可以理解的,因为我从来没有调用过模型.调用模型后,Detector 对象现在具有属性 extracted_feature.所以下面的代码会执行没有任何错误:

This gives me an error: AttributeError: 'Detector' object has no attribute 'extracted_feature', which is understandable since I have never called the model in the first place. After calling the model, Detector object now has the attribute extracted_feature. So the following code will execute without any error:

image_tensor_1 = uniform(shape=(1, 600, 600, 3))
y_hat = detector(image_tensor_1)
print(detector.extracted_feature.shape)

但是,在尝试通过运行 detector.save("my_model") 保存模型并将模型加载回新变量后 new_detector = load_model("my_model").运行以下代码时出错:

However, after trying to save the model by running detector.save("my_model") and load the model back to a new variable new_detector = load_model("my_model"). I got an error running the code below:

image_tensor_2 = uniform(shape=(1, 600, 600, 3))
y_hat = new_detector(image_tensor_2)
print(new_detector.extracted_feature.shape)

AttributeError: 'Detector' 对象没有属性 'extracted_feature'.

self.extracted_feature 是我用来计算梯度的.我需要继续跟踪它,这样渐变就不会是 None.我应该怎么做才能访问属性 extracted_feature?

self.extracted_feature is what I use to calculate the gradient. I need to keep tracking it so the gradient will not be None. What should I do to access the attribute extracted_feature?

推荐答案

你可以这样做

    def call(self, inputs):
        x = preprocess_input(inputs)
        extracted_feature = self.feature_extractor(x, training=False)
        x = self.flatten_layer(extracted_feature)
        x = self.prediction_layer(x)
        return extracted_feature, x

检查

image_tensor_1 = uniform(shape=(1, 32, 32, 3))
detector = Detector()
ex_feat, y_hat = detector(image_tensor_1)
print(ex_feat.shape)
(1, 1, 1, 512)

保存并重新加载.

detector.save("my_model")
new_detector = load_model("my_model")

image_tensor_2 = uniform(shape=(1, 32, 32, 3))
ex_feat, y_hat = new_detector(image_tensor_2)
print(ex_feat.shape)
(1, 1, 1, 512)

仅供参考,如果您想从基础模型获得中间层输出,那么您可能需要在 __init__ 方法中以这种方式初始化您的基础模型.

FYI, if you want to get intermediate layer output from the base model, then you may need to init your base model in that way in the __init__ method.

这篇关于Tensorflow 2:加载后不再能够跟踪子类模型的属性的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆