Load pre-trained keras model for continued training on google cloud


Problem description

I am trying to load a pre-trained Keras model for continued training on Google Cloud. It works locally by simply loading the discriminator and generator with

 model = load_model('myPretrainedModel.h5')

But obviously this doesn't work on Google Cloud. I have tried using the same method that I use to read the training data from my Google Storage bucket, with:

fil = "gs://mygcbucket/myPretrainedModel.h5"    
f = BytesIO(file_io.read_file_to_string(fil, binary_mode=True))
return np.load(f)

However, this doesn't seem to work for loading a model; I get the following error when running the job.

ValueError: Cannot load file containing pickled data when allow_pickle=False

Adding allow_pickle=True throws another error:

OSError: Failed to interpret file <_io.BytesIO object at 0x7fdf2bb42620> as a pickle

I then tried something someone suggested for a similar issue: as I understand it, temporarily re-saving the model from the bucket to local storage (relative to where the job is running) and then loading it, with:

fil = "gs://mygcbucket/myPretrainedModel.h5"  
model_file = file_io.FileIO(fil, mode='rb')
file_stream = file_io.FileIO(model_file, mode='r')
temp_model_location = './temp_model.h5'
temp_model_file = open(temp_model_location, 'wb')
temp_model_file.write(file_stream.read())
temp_model_file.close()
file_stream.close()
model = load_model(temp_model_location)
return model

However, this throws the following error:

TypeError: Expected binary or unicode string, got tensorflow.python.lib.io.file_io.FileIO object

I must admit I am not really sure what I need to do to actually load a pre-trained Keras model from my storage bucket and then use it in my training job on Google Cloud. Any help is deeply appreciated.

Recommended answer

I would suggest using AI Platform Notebooks to do so. Download the trained model using this method, and check the Python code under the Code samples tab. Once you have your model in the VM where the Notebook is running, you can load it as you were doing locally. Here you have an example where tf.keras.models.load_model is used.
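
For reference, a minimal sketch of that workflow (this is not the linked code sample): it assumes TensorFlow is built with GCS support, that the VM or job has read access to the bucket, and the local destination path is a placeholder. It copies the .h5 object to local disk with tf.io.gfile and then loads it with tf.keras.models.load_model:

import tensorflow as tf

# Path from the question; the local destination is a hypothetical placeholder.
gcs_path = "gs://mygcbucket/myPretrainedModel.h5"
local_path = "/tmp/myPretrainedModel.h5"

# Copy the HDF5 file to the local file system, then load it exactly as you would locally.
tf.io.gfile.copy(gcs_path, local_path, overwrite=True)
model = tf.keras.models.load_model(local_path)
model.summary()

From there, training can continue as it does locally, e.g. by calling model.fit on the data read from the bucket.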
