读取.h5文件速度极慢 [英] Reading .h5 file is extremely slow

查看:0
本文介绍了读取.h5文件速度极慢的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我的数据以.h5格式存储。我使用数据生成器来拟合模型,它非常慢。下面提供了我的代码片段。

def open_data_file(filename, readwrite="r"):
    return tables.open_file(filename, readwrite)

data_file_opened = open_data_file(os.path.abspath("../data/data.h5"))

train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        ......)

其中:

def get_training_and_validation_generators(data_file, batch_size, ...):
    training_generator = data_generator(data_file, training_list,....)

DATA_GENERATOR函数如下:

def data_generator(data_file, index_list,....):
      orig_index_list = index_list
    while True:
        x_list = list()
        y_list = list()
        if patch_shape:
            index_list = create_patch_index_list(orig_index_list, data_file, patch_shape,
                                                 patch_overlap, patch_start_offset,pred_specific=pred_specific)
        else:
            index_list = copy.copy(orig_index_list)

        while len(index_list) > 0:
            index = index_list.pop()
            add_data(x_list, y_list, data_file, index, augment=augment, augment_flip=augment_flip,
                     augment_distortion_factor=augment_distortion_factor, patch_shape=patch_shape,
                     skip_blank=skip_blank, permute=permute)
            if len(x_list) == batch_size or (len(index_list) == 0 and len(x_list) > 0):
                yield convert_data(x_list, y_list, n_labels=n_labels, labels=labels, num_model=num_model,overlap_label=overlap_label)
                x_list = list()
                y_list = list()

Add_Data()如下:

def add_data(x_list, y_list, data_file, index, augment=False, augment_flip=False, augment_distortion_factor=0.25,
             patch_shape=False, skip_blank=True, permute=False):
    '''
    add qualified x,y to the generator list
    '''
#     pdb.set_trace()
    data, truth = get_data_from_file(data_file, index, patch_shape=patch_shape)
    
    if np.sum(truth) == 0:
        return
    if augment:
        affine = np.load('affine.npy')
        data, truth = augment_data(data, truth, affine, flip=augment_flip, scale_deviation=augment_distortion_factor)

    if permute:
        if data.shape[-3] != data.shape[-2] or data.shape[-2] != data.shape[-1]:
            raise ValueError("To utilize permutations, data array must be in 3D cube shape with all dimensions having "
                             "the same length.")
        data, truth = random_permutation_x_y(data, truth[np.newaxis])
    else:
        truth = truth[np.newaxis]

    if not skip_blank or np.any(truth != 0):
        x_list.append(data)
        y_list.append(truth)

模型培训:

def train_model(model, model_file,....):
    model.fit(training_generator,
                        steps_per_epoch=steps_per_epoch,
                        epochs=n_epochs,
                        verbose = 2,
                        validation_data=validation_generator,
                        validation_steps=validation_steps)

我的数据集很大:data.h5是55 GB。大约需要700s才能完成一个时代。在大概6个时期之后,我得到了一个分割错误。批处理大小设置为1,否则会出现资源耗尽错误。有没有一种有效的方法来读取生成器中的data.h5,以便训练更快并且不会导致内存不足错误?

推荐答案

这是我答案的开始。我查看了您的代码,您有很多调用来读取.h5数据。根据我的统计,生成器对training_listvalidation_list的每个循环进行6次读取调用。所以,在一个训练循环中,这几乎是2万个呼叫。(我)不清楚是否在每个训练循环中都调用了发电机。如果是,则乘以2268个循环。

HDF5文件读取的效率取决于读取数据的调用次数(而不仅仅是数据量)。换句话说,在一次调用中读取1 GB的数据比一次读取1000个调用x 1MB的相同数据要快。因此,我们首先需要确定从HDF5文件读取数据所花费的时间(与您的7000相比)。

我隔离了读取数据文件的PyTables调用。在此基础上,我构建了一个简单的程序来模拟您的生成函数的行为。目前,它在整个样本列表上进行单个训练循环。如果希望运行更长时间的测试,请增加n_trainn_epoch值。(注:代码语法正确。但是没有文件,所以无法验证逻辑。我认为这是正确的,但您可能需要修复一些小错误。)

请参阅以下代码。它应该独立运行(所有依赖项都已导入)。 它打印基本的计时数据。运行它以对您的发电机进行基准测试。

import tables as tb
import numpy as np
from random import shuffle 
import time

with tb.open_file('../data/data.h5', 'r') as data_file:

    n_train = 1
    n_epochs = 1
    loops = n_train*n_epochs
    
    for e_cnt in range(loops):  
        nb_samples = data_file.root.truth.shape[0]
        sample_list = list(range(nb_samples))
        shuffle(sample_list)
        split = 0.80
        n_training = int(len(sample_list) * split)
        training_list = sample_list[:n_training]
        validation_list = sample_list[n_training:]
        
        start = time.time()
        for index_list in [ training_list, validation_list ]:
            shuffle(index_list)
            x_list = list()
            y_list = list()
            
            while len(index_list) > 0:
                index = index_list.pop() 
                
                brain_width = data_file.root.brain_width[index]
                x = np.array([modality_img[index,0,
                                           brain_width[0,0]:brain_width[1,0]+1,
                                           brain_width[0,1]:brain_width[1,1]+1,
                                           brain_width[0,2]:brain_width[1,2]+1] 
                              for modality_img in [data_file.root.t1,
                                                   data_file.root.t1ce,
                                                   data_file.root.flair,
                                                   data_file.root.t2]])
                y = data_file.root.truth[index, 0,
                                         brain_width[0,0]:brain_width[1,0]+1,
                                         brain_width[0,1]:brain_width[1,1]+1,
                                         brain_width[0,2]:brain_width[1,2]+1]    
                
                x_list.append(data)
                y_list.append(truth)
    
        print(f'For loop:{e_cnt}')
        print(f'Time to read all data={time.time()-start:.2f}')

这篇关于读取.h5文件速度极慢的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆