如何在theano中保存/序列化受过训练的模型? [英] How to save / serialize a trained model in theano?
问题描述
我按照加载和保存中的文档进行保存.
I saved the model as documented on loading and saving.
# saving trained model
f = file('models/simple_model.save', 'wb')
cPickle.dump(ca, f, protocol=cPickle.HIGHEST_PROTOCOL)
f.close()
ca
是训练有素的自动编码器.它是类 cA
的实例.从构建和保存模型的脚本中,我可以毫无问题地调用ca.get_reconstructed_input(...)
和ca.get_hidden_values(...)
.
ca
is a trained auto-encoder. It's a instance of class cA
. From the script in which I build and save the model I can call ca.get_reconstructed_input(...)
and ca.get_hidden_values(...)
without any problem.
在另一个脚本中,我尝试加载经过训练的模型.
In a different script I try to load the trained model.
# loading the trained model
model_file = file('models/simple_model.save', 'rb')
ca = cPickle.load(model_file)
model_file.close()
我收到以下错误.
ca = cPickle.load(model_file)
AttributeError:模块"对象没有属性"cA"
AttributeError: 'module' object has no attribute 'cA'
推荐答案
腌制对象的所有类定义都需要进行解腌的脚本知道.其他StackOverflow问题(例如 AttributeError:AttributeError:模块"对象具有没有属性新人" ).
All the class definitions of the pickled objects need to be known by the script that does the unpickling. There is more on this in other StackOverflow questions (e.g. AttributeError: 'module' object has no attribute 'newperson').
您的代码正确无误,只要您正确导入cA
.鉴于错误,您可能并非如此.确保您使用的是from cA import cA
而不是import cA
.
Your code is correct as long as you properly import cA
. Given the error you're getting it may not be the case. Make sure you're using from cA import cA
and not just import cA
.
或者,您的模型是由其参数定义的,因此您可以选择仅腌制参数值).可以通过两种方式完成此操作,具体取决于您的观点.
Alternatively, your model is defined by its parameters so you could instead just pickle the parameter values). This could be done in two ways depending on what you point of view.
-
保存Theano共享变量.在这里,我们假设
ca.params
是Theano共享变量实例的常规Python列表.
Save the Theano shared variables. Here we assume that
ca.params
is a regular Python list of Theano shared variable instances.
cPickle.dump(ca.params, f, protocol=cPickle.HIGHEST_PROTOCOL)
保存Theano共享变量中存储的numpy数组.
Save the numpy arrays stored inside the Theano shared variables.
cPickle.dump([param.get_value() for param in ca.params], f, protocol=cPickle.HIGHEST_PROTOCOL)
要加载模型时,需要重新初始化参数.例如,创建cA
类的新实例,然后选择
When you want to load the model you'll need to reinitialize the parameters. For example, create a new instance of the cA
class then either
ca.params = cPickle.load(f)
ca.W, ca.b, ca.b_prime = ca.params
或
ca.params = [theano.shared(param) for param in cPickle.load(f)]
ca.W, ca.b, ca.b_prime = ca.params
请注意,您需要同时设置params
字段和单独的参数字段.
Note that you need to set both the params
field and the separate parameters fields.
这篇关于如何在theano中保存/序列化受过训练的模型?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!