Save Keras ModelCheckpoints in Google Cloud Bucket


Problem description

I'm training an LSTM network on Google Cloud Machine Learning Engine, using Keras with the TensorFlow backend. After some adjustments to gcloud and my Python script, I managed to deploy my model and run a successful training job.

I then tried to make my model save checkpoints after every epoch using the Keras ModelCheckpoint callback. Running a local training job with Google Cloud works perfectly as expected: the weights are stored at the specified path after each epoch. But when I run the same job online on Google Cloud Machine Learning Engine, weights.hdf5 does not get written to my Google Cloud Bucket. Instead I get the following error:

...
File "h5f.pyx", line 71, in h5py.h5f.open (h5py/h5f.c:1797)
IOError: Unable to open file (Unable to open file: name = 
'gs://.../weights.hdf5', errno = 2, error message = 'no such file or
directory', flags = 0, o_flags = 0)
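
For context, the callback was wired up roughly like this (a minimal sketch, not the original training script: the model and data are stand-ins, and 'gs://...' is the elided bucket path):

import numpy as np
from keras.models import Sequential
from keras.layers import LSTM, Dense
from keras.callbacks import ModelCheckpoint

# Stand-in model and data; the real LSTM network is not shown here
model = Sequential([LSTM(8, input_shape=(5, 3)), Dense(1)])
model.compile(optimizer='adam', loss='mse')
x = np.random.rand(16, 5, 3)
y = np.random.rand(16, 1)

# Checkpoint to the bucket after every epoch; this is the call that
# fails on ML Engine with the IOError above
checkpoint = ModelCheckpoint('gs://.../weights.hdf5')
model.fit(x, y, epochs=2, callbacks=[checkpoint])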

I investigated this issue and it turned out that there is no problem with the bucket itself, as the Keras TensorBoard callback works fine and writes the expected output to the same bucket. I also made sure that h5py gets included by listing it in the setup.py, located at:

├── setup.py
└── trainer
    ├── __init__.py
    ├── ...

The actual include in setup.py is shown below:

# setup.py
from setuptools import setup, find_packages

setup(name='kerasLSTM',
      version='0.1',
      packages=find_packages(),
      author='Kevin Katzke',
      install_requires=['keras', 'h5py', 'simplejson'],
      zip_safe=False)

I guess the problem comes down to the fact that GCS cannot be accessed with Python's built-in open for I/O; TensorFlow instead provides a custom implementation:

from tensorflow.python.lib.io import file_io

# Note: the mode must be 'w' (not 'r') to write to the bucket
with file_io.FileIO("gs://...", 'w') as f:
    f.write("Hi!")

After checking how the Keras ModelCheckpoint callback implements the actual file writing, it turned out that it uses h5py.File() for I/O:

with h5py.File(filepath, mode='w') as f:
    f.attrs['keras_version'] = str(keras_version).encode('utf8')
    f.attrs['backend'] = K.backend().encode('utf8')
    f.attrs['model_config'] = json.dumps({
        'class_name': model.__class__.__name__,
        'config': model.get_config()
    }, default=get_json_type).encode('utf8')

And as the h5py package is a Pythonic interface to the HDF5 binary data format, h5py.File() calls into the underlying HDF5 library, which is written in C (source, documentation) and knows nothing about gs:// paths.

How can I fix this and make the ModelCheckpoint callback write to my GCS bucket? Is there some way of "monkey patching" how an hdf5 file is opened, to make it use GCS's file_io.FileIO()?

Recommended answer

The issue can be solved with the following piece of code:

from tensorflow.python.lib.io import file_io

# Save Keras ModelCheckpoints locally
model.save('model.h5')

# Copy model.h5 over to Google Cloud Storage; 'gs://...' stands for
# the bucket path, as elsewhere in this post. Binary modes ('rb'/'wb+')
# keep the HDF5 contents intact during the copy.
with file_io.FileIO('model.h5', mode='rb') as input_f:
    with file_io.FileIO('gs://.../model.h5', mode='wb+') as output_f:
        output_f.write(input_f.read())
        print("Saved model.h5 to GCS")

model.h5 is saved on the local filesystem and then copied over to GCS. As Jochen pointed out, there is currently no easy support for writing HDF5 model checkpoints directly to GCS. With this hack it is at least possible to write the data until an easier solution is provided.
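
To get a checkpoint after every epoch rather than a single save at the end of training, the same save-then-copy trick can be wrapped in a small custom callback. Below is a minimal sketch; GCSModelCheckpoint and the gs://my-bucket path are hypothetical names, not part of the original answer:

from keras.callbacks import Callback
from tensorflow.python.lib.io import file_io


class GCSModelCheckpoint(Callback):
    """Save the model locally after each epoch, then copy the file to GCS."""

    def __init__(self, local_path, gcs_path):
        super(GCSModelCheckpoint, self).__init__()
        self.local_path = local_path  # e.g. 'model.h5'
        self.gcs_path = gcs_path      # e.g. 'gs://my-bucket/model.h5' (hypothetical)

    def on_epoch_end(self, epoch, logs=None):
        # h5py writes to the local filesystem without problems
        self.model.save(self.local_path)
        # Binary modes so the HDF5 contents survive the copy to the bucket
        with file_io.FileIO(self.local_path, mode='rb') as input_f:
            with file_io.FileIO(self.gcs_path, mode='wb+') as output_f:
                output_f.write(input_f.read())


# Usage: model.fit(x, y, epochs=10,
#                  callbacks=[GCSModelCheckpoint('model.h5', 'gs://my-bucket/model.h5')])

Alternatively, file_io.copy('model.h5', gcs_path, overwrite=True) from the same module performs the copy in a single call.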

