Save Keras ModelCheckpoints in Google Cloud Bucket


Question

I'm training an LSTM network on Google Cloud Machine Learning Engine using Keras with the TensorFlow backend. After some adjustments to gcloud and my Python script, I managed to deploy my model and run a successful training job.

I then tried to make my model save checkpoints after every epoch using the Keras ModelCheckpoint callback. Running a local training job with Google Cloud works perfectly as expected: after each epoch the weights are stored at the specified path. But when I run the same job online on Google Cloud Machine Learning Engine, weights.hdf5 does not get written to my Google Cloud Bucket. Instead I get the following error:

...
File "h5f.pyx", line 71, in h5py.h5f.open (h5py/h5f.c:1797)
IOError: Unable to open file (Unable to open file: name = 
'gs://.../weights.hdf5', errno = 2, error message = 'no such file or
directory', flags = 0, o_flags = 0)
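
For reference, a minimal sketch of the kind of setup that triggers this error; the toy model, data shapes, and bucket name are assumptions, not the original code:

# Hypothetical reproduction: checkpointing an LSTM straight to a gs:// path
import numpy as np
from keras.models import Sequential
from keras.layers import LSTM, Dense
from keras.callbacks import ModelCheckpoint

model = Sequential([LSTM(8, input_shape=(5, 3)), Dense(1)])
model.compile(optimizer='adam', loss='mse')

x = np.random.rand(16, 5, 3)
y = np.random.rand(16, 1)

# Fails on ML Engine: h5py cannot open gs:// paths (bucket name is made up)
checkpoint = ModelCheckpoint('gs://my-bucket/weights.hdf5')
model.fit(x, y, epochs=2, callbacks=[checkpoint])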

I investigated this issue and it turned out that there is no problem with the bucket itself, as the Keras TensorBoard callback works fine and writes the expected output to the same bucket. I also made sure that h5py gets included by listing it in the setup.py located at:

├── setup.py
└── trainer
    ├── __init__.py
    ├── ...

The actual include in setup.py is shown below:

# setup.py
from setuptools import setup, find_packages

setup(name='kerasLSTM',
      version='0.1',
      packages=find_packages(),
      author='Kevin Katzke',
      install_requires=['keras', 'h5py', 'simplejson'],
      zip_safe=False)

I guess the problem comes down to the fact that GCS cannot be accessed with Python's built-in open for I/O; TensorFlow instead provides a custom implementation:

from tensorflow.python.lib.io import file_io

# Mode must be 'w' (not 'r') since we are writing to the file
with file_io.FileIO("gs://...", 'w') as f:
    f.write("Hi!")

After checking how the Keras ModelCheckpoint callback implements the actual file writing, it turned out that it uses h5py.File() for I/O:

with h5py.File(filepath, mode='w') as f:
    f.attrs['keras_version'] = str(keras_version).encode('utf8')
    f.attrs['backend'] = K.backend().encode('utf8')
    f.attrs['model_config'] = json.dumps({
        'class_name': model.__class__.__name__,
        'config': model.get_config()
    }, default=get_json_type).encode('utf8')

And since the h5py package is a Pythonic interface to the HDF5 binary data format, h5py.File() calls into the underlying HDF5 library, which is written in C and only understands local filesystem paths: source, documentation.
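
A minimal illustration of that mismatch, with a hypothetical local path: the HDF5 library resolves paths through the local filesystem, so a gs:// URL is never found:

import h5py

# Works: a plain local path, handled by the HDF5 C library
with h5py.File('/tmp/weights.hdf5', mode='w') as f:
    f.attrs['note'] = u'written locally'

# Fails with errno 2 ('no such file or directory'): HDF5 cannot
# resolve gs:// URLs, which is exactly the error shown above
# h5py.File('gs://my-bucket/weights.hdf5', mode='w')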

How can I fix this and make the ModelCheckpoint callback write to my GCS bucket? Is there a way to "monkey patch" how an hdf5 file is opened, so that it uses GCS's file_io.FileIO()?

Answer

The issue can be solved with the following piece of code:

# Save Keras ModelCheckpoints locally
model.save('model.h5')

# Copy model.h5 over to Google Cloud Storage: read the local file in
# binary mode and write it to the bucket (path elided as in the question)
with file_io.FileIO('model.h5', mode='rb') as input_f:
    with file_io.FileIO('gs://.../model.h5', mode='wb+') as output_f:
        output_f.write(input_f.read())
        print("Saved model.h5 to GCS")

The model.h5 is saved on the local filesystem and then copied over to GCS. As Jochen pointed out, there is currently no easy support for writing HDF5 model checkpoints directly to GCS. With this workaround it is possible to write the data until an easier solution is provided.
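
Since the question asked for a checkpoint after every epoch, the copy step can be wrapped in a custom Keras callback. This is a minimal sketch under that assumption; GCSCheckpoint and the bucket path are made-up names, not a Keras API:

from keras.callbacks import Callback
from tensorflow.python.lib.io import file_io

class GCSCheckpoint(Callback):
    """Save the model locally after each epoch, then copy it to GCS."""

    def __init__(self, gcs_path, local_path='model.h5'):
        super(GCSCheckpoint, self).__init__()
        self.gcs_path = gcs_path      # e.g. 'gs://my-bucket/model.h5' (hypothetical)
        self.local_path = local_path  # temporary file on the local disk

    def on_epoch_end(self, epoch, logs=None):
        self.model.save(self.local_path)
        with file_io.FileIO(self.local_path, mode='rb') as input_f:
            with file_io.FileIO(self.gcs_path, mode='wb+') as output_f:
                output_f.write(input_f.read())

Usage: pass GCSCheckpoint('gs://my-bucket/model.h5') in the callbacks list of model.fit(), in place of the built-in ModelCheckpoint.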

