使用MXnet时如何保存模型 [英] How to save a model when using MXnet

查看:508
本文介绍了使用MXnet时如何保存模型的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在使用MXnet来训练CNN(在R中),并且可以使用以下代码来训练模型而没有任何错误:

I am using MXnet for training a CNN (in R) and I can train the model without any error with the following code:

model <- mx.model.FeedForward.create(symbol=network,
                                     X=train.iter,
                                     ctx=mx.gpu(0),
                                     num.round=20,
                                     array.batch.size=batch.size,
                                     learning.rate=0.1,
                                     momentum=0.1,  
                                     eval.metric=mx.metric.accuracy,
                                     wd=0.001,
                                     batch.end.callback=mx.callback.log.speedometer(batch.size, frequency = 100)
    )

但是,由于此过程很耗时,因此我在夜间在服务器上运行它,我想保存该模型以供完成后使用

But as this process is time-consuming, I run it on a server during the night and I want to save the model for the purpose of using it after finishing the training.

我用过:

save(list = ls(), file="mymodel.RData")

mx.model.save("mymodel", 10)

但是它们都无法保存模型!例如,当我加载 mymodel.RData 时,我无法预测测试集的标签!

But none of them can save the model! for example when I load the "mymodel.RData", I can not predict the labels for the test set!

另一个例子是,当我加载 mymodel.RData 并尝试使用以下代码对其进行绘制时:

Another example is when I load the "mymodel.RData" and try to plot it with the following code:

graph.viz(model$symbol$as.json())

我收到以下错误:

Error in model$symbol$as.json() : external pointer is not valid

有人可以给我一个保存然后加载此模型以供将来使用的解决方案吗?

Can anybody give me a solution for saving and then loading this model for future use?

谢谢

推荐答案

您可以通过

model <- mx.model.FeedForward.create(symbol=network,
                                 X=train.iter,
                                 ctx=mx.gpu(0),
                                 num.round=20,
                                 array.batch.size=batch.size,
                                 learning.rate=0.1,
                                 momentum=0.1,  
                                 eval.metric=mx.metric.accuracy,
                                 wd=0.001,
                                 epoch.end.callback=mx.callback.save.checkpoint("model_prefix")
                                 batch.end.callback=mx.callback.log.speedometer(batch.size, frequency = 100)
)

这篇关于使用MXnet时如何保存模型的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆