如何在 theano 中保存/序列化经过训练的模型? [英] How to save / serialize a trained model in theano?

查看:46
本文介绍了如何在 theano 中保存/序列化经过训练的模型?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我按照加载和保存中的说明保存了模型.

# 保存训练好的模型f = file('models/simple_model.save', 'wb')cPickle.dump(ca, f, 协议=cPickle.HIGHEST_PROTOCOL)f.close()

ca 是经过训练的自动编码器.它是 cA 类的一个实例.从我构建和保存模型的脚本中,我可以毫无问题地调用 ca.get_reconstructed_input(...)ca.get_hidden_​​values(...).

在不同的脚本中,我尝试加载经过训练的模型.

# 加载训练好的模型model_file = file('models/simple_model.save', 'rb')ca = cPickle.load(model_file)model_file.close()

我收到以下错误.

<块引用>

ca = cPickle.load(model_file)

AttributeError: 'module' 对象没有属性 'cA'

解决方案

执行 unpickling 的脚本需要知道所有 pickled 对象的类定义.在其他 StackOverflow 问题中有更多关于此的内容(例如 AttributeError: 'module' object has没有属性新人").

只要您正确导入cA,您的代码就是正确的.鉴于您遇到的错误可能并非如此.确保您使用的是 from cA import cA 而不仅仅是 import cA.

或者,您的模型由其参数定义,因此您可以只腌制参数值).这可以通过两种方式完成,具体取决于您的观点.

  1. 保存 Theano 共享变量.这里我们假设 ca.params 是 Theano 共享变量实例的常规 Python 列表.

    cPickle.dump(ca.params, f, 协议=cPickle.HIGHEST_PROTOCOL)

  2. 保存存储在 Theano 共享变量中的 numpy 数组.

    cPickle.dump([param.get_value() for param in ca.params], f, protocol=cPickle.HIGHEST_PROTOCOL)

当您想要加载模型时,您需要重新初始化参数.例如,创建 cA 类的新实例,然后

ca.params = cPickle.load(f)ca.W, ca.b, ca.b_prime = ca.params

ca.params = [theano.shared(param) for param in cPickle.load(f)]ca.W, ca.b, ca.b_prime = ca.params

请注意,您需要同时设置 params 字段和单独的参数字段.

I saved the model as documented on loading and saving.

# saving trained model
f = file('models/simple_model.save', 'wb')
cPickle.dump(ca, f, protocol=cPickle.HIGHEST_PROTOCOL)
f.close()

ca is a trained auto-encoder. It's a instance of class cA. From the script in which I build and save the model I can call ca.get_reconstructed_input(...) and ca.get_hidden_values(...) without any problem.

In a different script I try to load the trained model.

# loading the trained model
model_file = file('models/simple_model.save', 'rb')
ca = cPickle.load(model_file)
model_file.close()

I receive the following error.

ca = cPickle.load(model_file)

AttributeError: 'module' object has no attribute 'cA'

解决方案

All the class definitions of the pickled objects need to be known by the script that does the unpickling. There is more on this in other StackOverflow questions (e.g. AttributeError: 'module' object has no attribute 'newperson').

Your code is correct as long as you properly import cA. Given the error you're getting it may not be the case. Make sure you're using from cA import cA and not just import cA.

Alternatively, your model is defined by its parameters so you could instead just pickle the parameter values). This could be done in two ways depending on what you point of view.

  1. Save the Theano shared variables. Here we assume that ca.params is a regular Python list of Theano shared variable instances.

    cPickle.dump(ca.params, f, protocol=cPickle.HIGHEST_PROTOCOL)
    

  2. Save the numpy arrays stored inside the Theano shared variables.

    cPickle.dump([param.get_value() for param in ca.params], f, protocol=cPickle.HIGHEST_PROTOCOL)
    

When you want to load the model you'll need to reinitialize the parameters. For example, create a new instance of the cA class then either

ca.params = cPickle.load(f)
ca.W, ca.b, ca.b_prime = ca.params

or

ca.params = [theano.shared(param) for param in cPickle.load(f)]
ca.W, ca.b, ca.b_prime = ca.params

Note that you need to set both the params field and the separate parameters fields.

这篇关于如何在 theano 中保存/序列化经过训练的模型?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
相关文章
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆