Tensorflow Probability:保存和加载模型 [英] Tensorflow Probability: saving and loading model

查看:63
本文介绍了Tensorflow Probability:保存和加载模型的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在尝试使用 TensorFlow 概率拟合模型,例如:

I am trying to fit a model with TensorFlow probability, for example:

input = Input(shape=(32,32,32,3))
x = tfp.layers.Convolution3DReparameterization(
        64, kernel_size=5, padding='SAME', activation=tf.nn.relu,
        data_format = 'channels_first')(input)
x = tf.keras.layers.MaxPooling3D(pool_size=(2, 2, 2),
                                 strides=(2, 2, 2),
                                 padding='SAME')(x)
x = tf.keras.layers.Flatten()(x)
output = tfp.layers.DenseFlipout(10)(x)

model3 = Model(input, output)
model3.save('tf_test_model3.h5')

当我将模型加载为 model3 = load_model('tf_test_model3.h5') 时,出现以下错误:

When I load the model as model3 = load_model('tf_test_model3.h5'), I get the following error:

ValueError: Unknown layer: Conv3DReparameterization. Please ensure this object is passed to the `custom_objects` argument. 

当我将它传递给 custom_objects 时:

When I pass it to custom_objects as:

custom_objects= {'Conv3DReparameterization': tfp.layers.Convolution3DReparameterization}
model3 = load_model('tf_test_model3.h5', custom_objects=custom_objects)

我收到以下错误:TypeError: 'str' 对象不可调用

我做错了什么?我该如何解决这个问题?

What am I doing wrong? How can I fix this?

推荐答案

您代码中的错误是 data_format = channels_first

您给出的输入形状为 (32,32,32,3),它是 channel_last 类型的数据.提供编辑过的代码

You are giving input shape of (32,32,32,3) which is channel_last type data. providing edited code

import keras as tk
import tensorflow as tf
import tensorflow_probability as tfp
input = tf.keras.Input(shape=(32,32,32,3))
x = tfp.layers.Convolution3DReparameterization(
        64, kernel_size=5, padding='SAME', activation=tf.nn.relu,
        data_format = 'channels_last')(input) #changed  from channels_first
x = tf.keras.layers.MaxPooling3D(pool_size=(2, 2, 2),
                                 strides=(2, 2, 2),
                                 padding='SAME')(x)
x = tf.keras.layers.Flatten()(x)
output = tfp.layers.DenseFlipout(10)(x)

model3 = tk.Model(input, output)
model3.save('tf_test_model3.h5')

这篇关于Tensorflow Probability:保存和加载模型的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆