如何在theano中保存/序列化受过训练的模型? [英] How to save / serialize a trained model in theano?

查看:194
本文介绍了如何在theano中保存/序列化受过训练的模型?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我按照加载和保存中的文档进行保存.

I saved the model as documented on loading and saving.

# saving trained model
f = file('models/simple_model.save', 'wb')
cPickle.dump(ca, f, protocol=cPickle.HIGHEST_PROTOCOL)
f.close()

ca是训练有素的自动编码器.它是类 cA 的实例.从构建和保存模型的脚本中,我可以毫无问题地调用ca.get_reconstructed_input(...)ca.get_hidden_values(...).

ca is a trained auto-encoder. It's a instance of class cA. From the script in which I build and save the model I can call ca.get_reconstructed_input(...) and ca.get_hidden_values(...) without any problem.

在另一个脚本中,我尝试加载经过训练的模型.

In a different script I try to load the trained model.

# loading the trained model
model_file = file('models/simple_model.save', 'rb')
ca = cPickle.load(model_file)
model_file.close()

我收到以下错误.

ca = cPickle.load(model_file)

AttributeError:模块"对象没有属性"cA"

AttributeError: 'module' object has no attribute 'cA'

推荐答案

腌制对象的所有类定义都需要进行解腌的脚本知道.其他StackOverflow问题(例如 AttributeError:AttributeError:模块"对象具有没有属性新人" ).

All the class definitions of the pickled objects need to be known by the script that does the unpickling. There is more on this in other StackOverflow questions (e.g. AttributeError: 'module' object has no attribute 'newperson').

您的代码正确无误,只要您正确导入cA.鉴于错误,您可能并非如此.确保您使用的是from cA import cA而不是import cA.

Your code is correct as long as you properly import cA. Given the error you're getting it may not be the case. Make sure you're using from cA import cA and not just import cA.

或者,您的模型是由其参数定义的,因此您可以选择仅腌制参数值).可以通过两种方式完成此操作,具体取决于您的观点.

Alternatively, your model is defined by its parameters so you could instead just pickle the parameter values). This could be done in two ways depending on what you point of view.

  1. 保存Theano共享变量.在这里,我们假设ca.params是Theano共享变量实例的常规Python列表.

  1. Save the Theano shared variables. Here we assume that ca.params is a regular Python list of Theano shared variable instances.

cPickle.dump(ca.params, f, protocol=cPickle.HIGHEST_PROTOCOL)

  • 保存Theano共享变量中存储的numpy数组.

  • Save the numpy arrays stored inside the Theano shared variables.

    cPickle.dump([param.get_value() for param in ca.params], f, protocol=cPickle.HIGHEST_PROTOCOL)
    

  • 要加载模型时,需要重新初始化参数.例如,创建cA类的新实例,然后选择

    When you want to load the model you'll need to reinitialize the parameters. For example, create a new instance of the cA class then either

    ca.params = cPickle.load(f)
    ca.W, ca.b, ca.b_prime = ca.params
    

    ca.params = [theano.shared(param) for param in cPickle.load(f)]
    ca.W, ca.b, ca.b_prime = ca.params
    

    请注意,您需要同时设置params字段和单独的参数字段.

    Note that you need to set both the params field and the separate parameters fields.

    这篇关于如何在theano中保存/序列化受过训练的模型?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

    查看全文
    登录 关闭
    扫码关注1秒登录
    发送“验证码”获取 | 15天全站免登陆