Tensorflow 2:加载后不再能够跟踪子类模型的属性 [英] Tensorflow 2: No Longer Able to Track Attributes of a Subclassed Model After Loaded
问题描述
这是我在 Tensorflow 2.5 中实现的子类模型:
Here is my implementation of a Subclassed Model in Tensorflow 2.5:
from tensorflow.keras import Model, Input
from tensorflow.keras.applications import DenseNet201
from tensorflow.keras.applications.densenet import preprocess_input
from tensorflow.keras.layers import Conv2D, Flatten, Dense
from tensorflow.random import uniform
from tensorflow.keras.models import load_model
class Detector(Model):
def __init__(self, num_classes=3, name="DenseNet201"):
super(Detector, self).__init__(name=name)
self.feature_extractor = DenseNet201(
include_top=False,
weights="imagenet",
)
self.feature_extractor.trainable = False
self.flatten_layer = Flatten()
self.prediction_layer = Dense(num_classes, activation=None)
def call(self, inputs):
x = preprocess_input(inputs)
self.extracted_feature = self.feature_extractor(x, training=False)
x = self.flatten_layer(self.extracted_feature)
x = self.prediction_layer(x)
return x
在测试我的代码时,我发现了一些让我感到困惑的东西.
While testing my code, I found something that really confused me.
detector = Detector()
print(detector.extracted_feature)
这给了我一个错误:AttributeError: 'Detector' object has no attribute 'extracted_feature',这是可以理解的,因为我从来没有调用过模型.调用模型后,Detector
对象现在具有属性 extracted_feature
.所以下面的代码会执行没有任何错误:
This gives me an error: AttributeError: 'Detector' object has no attribute 'extracted_feature', which is understandable since I have never called the model in the first place. After calling the model, Detector
object now has the attribute extracted_feature
. So the following code will execute without any error:
image_tensor_1 = uniform(shape=(1, 600, 600, 3))
y_hat = detector(image_tensor_1)
print(detector.extracted_feature.shape)
但是,在尝试通过运行 detector.save("my_model")
保存模型并将模型加载回新变量后 new_detector = load_model("my_model")代码>.运行以下代码时出错:
However, after trying to save the model by running detector.save("my_model")
and load the model back to a new variable new_detector = load_model("my_model")
. I got an error running the code below:
image_tensor_2 = uniform(shape=(1, 600, 600, 3))
y_hat = new_detector(image_tensor_2)
print(new_detector.extracted_feature.shape)
AttributeError: 'Detector' 对象没有属性 'extracted_feature'.
self.extracted_feature
是我用来计算梯度的.我需要继续跟踪它,这样渐变就不会是 None
.我应该怎么做才能访问属性 extracted_feature
?
self.extracted_feature
is what I use to calculate the gradient. I need to keep tracking it so the gradient will not be None
. What should I do to access the attribute extracted_feature
?
推荐答案
你可以这样做
def call(self, inputs):
x = preprocess_input(inputs)
extracted_feature = self.feature_extractor(x, training=False)
x = self.flatten_layer(extracted_feature)
x = self.prediction_layer(x)
return extracted_feature, x
检查
image_tensor_1 = uniform(shape=(1, 32, 32, 3))
detector = Detector()
ex_feat, y_hat = detector(image_tensor_1)
print(ex_feat.shape)
(1, 1, 1, 512)
保存并重新加载.
detector.save("my_model")
new_detector = load_model("my_model")
image_tensor_2 = uniform(shape=(1, 32, 32, 3))
ex_feat, y_hat = new_detector(image_tensor_2)
print(ex_feat.shape)
(1, 1, 1, 512)
仅供参考,如果您想从基础模型获得中间层输出,那么您可能需要在 __init__
方法中以这种方式初始化您的基础模型.
FYI, if you want to get intermediate layer output from the base model, then you may need to init your base model in that way in the __init__
method.
这篇关于Tensorflow 2:加载后不再能够跟踪子类模型的属性的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!