TensorFlow: convert .hdf5 to tfrecord

Question

I have the COCO dataset in .h5 format. I need to convert it to .record (TFRecord) files so that I can train my embedded_ssd_mobilenet with the Object Detection API. How can I do this?

Recommended answer

Here's a script I used to convert some data from HDF5 to TFRecord. You'll obviously have to modify the column names (and `DATASET_NAME`) to match your file.

```python
import os

import h5py
import tensorflow as tf

CHUNK_SIZE = 5000      # examples per output .tfrecord shard
DATASET_NAME = "data"  # name of the compound HDF5 dataset to convert

# One converter per column, mapping an HDF5 value to a tf.train.Feature.
# "vals_shape" is not read from the file: it is derived from "vals" so the
# flattened array can be reshaped when the records are read back.
COL_NAMES = {
    "index": lambda x: tf.train.Feature(
        int64_list=tf.train.Int64List(value=[int(x)])),
    "vals": lambda x: tf.train.Feature(
        float_list=tf.train.FloatList(value=x.reshape(-1).tolist())),
    "vals_shape": lambda x: tf.train.Feature(
        int64_list=tf.train.Int64List(value=list(x.shape))),
}


def write_tfrecords(file_path, features_list):
    """Serialize a list of tf.train.Features into one TFRecord file."""
    # tf.io.TFRecordWriter is the TF 2.x name of tf.python_io.TFRecordWriter.
    with tf.io.TFRecordWriter(file_path) as writer:
        for features in features_list:
            writer.write(tf.train.Example(features=features).SerializeToString())


def hdf5_row_to_features(hdf5_row):
    """Convert one row of the HDF5 dataset into tf.train.Features."""
    feature_dict = {}
    for col_name, to_feature in COL_NAMES.items():
        if col_name == "vals_shape":
            continue  # derived from "vals" below; there is no such HDF5 column
        feature_dict[col_name] = to_feature(hdf5_row[col_name])
        if col_name == "vals":
            feature_dict["vals_shape"] = COL_NAMES["vals_shape"](hdf5_row[col_name])
    return tf.train.Features(feature=feature_dict)


def convert_records(file_path):
    """Write <name>-0.tfrecord, <name>-1.tfrecord, ... next to the HDF5 file,
    each holding at most CHUNK_SIZE examples."""
    base_file_name = os.path.splitext(file_path)[0]  # keeps the directory part
    tfrecord_file_name_template = "%s-%d.tfrecord"
    tfrecord_file_counter = 0
    features_list = []

    with h5py.File(file_path, "r") as hdf5_file:
        dataset = hdf5_file[DATASET_NAME]
        num_rows = dataset.shape[0]  # number of rows (.size counts all elements)
        print("Dataset size: %d" % num_rows)
        for index in range(num_rows):
            if index % 100 == 0:
                print("iteration index: %d" % index)
            features_list.append(hdf5_row_to_features(dataset[index]))

            # Write each full chunk to its own shard.
            if len(features_list) == CHUNK_SIZE:
                write_tfrecords(
                    tfrecord_file_name_template % (base_file_name, tfrecord_file_counter),
                    features_list)
                tfrecord_file_counter += 1
                features_list = []

    # Write the remaining (< CHUNK_SIZE) examples, if any.
    if features_list:
        write_tfrecords(
            tfrecord_file_name_template % (base_file_name, tfrecord_file_counter),
            features_list)
```
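
For the Object Detection API specifically, renaming columns is not quite enough: its input decoder reads particular feature keys (`image/encoded`, `image/height`, `image/width`, `image/object/bbox/xmin`, ..., `image/object/class/label`), with box coordinates normalized to [0, 1]. Below is a minimal sketch of a per-row converter for that format. It assumes, hypothetically, that each HDF5 row already gives you the JPEG-encoded image bytes, the image height and width, COCO-style pixel boxes `[x, y, width, height]`, and integer class ids matching your label map; `coco_row_to_od_example` and its helpers are illustrative names, not part of any library. Compare the key list with `create_coco_tf_record.py` in the TensorFlow models repository for the version you train with.

```python
import tensorflow as tf


def _int64_list(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(v) for v in values]))


def _float_list(values):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[float(v) for v in values]))


def _bytes_list(values):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=list(values)))


def coco_row_to_od_example(encoded_jpeg, height, width, boxes, class_ids):
    """Hypothetical converter: one image plus its boxes -> tf.train.Example.

    boxes: iterable of COCO-style [x, y, box_width, box_height] in pixels.
    class_ids: integer ids matching the label map used for training.
    """
    xmin, ymin, xmax, ymax = [], [], [], []
    for x, y, w, h in boxes:
        # Normalize pixel coordinates to [0, 1] by the image size.
        xmin.append(x / width)
        ymin.append(y / height)
        xmax.append((x + w) / width)
        ymax.append((y + h) / height)
    feature = {
        "image/height": _int64_list([height]),
        "image/width": _int64_list([width]),
        "image/encoded": _bytes_list([encoded_jpeg]),
        "image/format": _bytes_list([b"jpeg"]),
        "image/object/bbox/xmin": _float_list(xmin),
        "image/object/bbox/ymin": _float_list(ymin),
        "image/object/bbox/xmax": _float_list(xmax),
        "image/object/bbox/ymax": _float_list(ymax),
        "image/object/class/label": _int64_list(class_ids),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))
```

To plug this into the script above, have `write_tfrecords` write `example.SerializeToString()` for these examples directly instead of wrapping a `tf.train.Features` object.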

During development, you can also use this snippet to inspect the data in your TFRecord file and check that everything was written correctly:

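```python
def inspect_tf_record_file(file_path, result_chunking=1):
    """Print every result_chunking-th example in a TFRecord file, then the total."""
    count = 0
    # Iterating a TFRecordDataset directly requires eager execution (TF 2.x).
    for record in tf.data.TFRecordDataset(file_path):
        result = tf.train.Example.FromString(record.numpy())
        if count % result_chunking == 0:
            print("result: %s" % result)
        count += 1
    print("Total count: %d" % count)
```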
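
To read the records back as tensors rather than printing raw protos, here is a TF 2.x sketch that parses each example with `tf.io.parse_single_example` and uses `vals_shape` to restore the original array shape. The feature spec assumes the three columns written by the script above (`index`, `vals`, `vals_shape`); `FEATURE_SPEC`, `parse_record` and `load_tfrecords` are illustrative names, so adjust the spec if you changed the columns.

```python
import tensorflow as tf

# Mirrors the writer: a scalar int64 "index", a variable-length flattened
# float array "vals", and its original shape "vals_shape".
FEATURE_SPEC = {
    "index": tf.io.FixedLenFeature([], tf.int64),
    "vals": tf.io.VarLenFeature(tf.float32),
    "vals_shape": tf.io.VarLenFeature(tf.int64),
}


def parse_record(serialized):
    """Decode one serialized tf.train.Example into (index, vals)."""
    parsed = tf.io.parse_single_example(serialized, FEATURE_SPEC)
    vals = tf.sparse.to_dense(parsed["vals"])          # flat float32 values
    shape = tf.sparse.to_dense(parsed["vals_shape"])   # original dimensions
    return parsed["index"], tf.reshape(vals, shape)


def load_tfrecords(file_paths):
    """Return a tf.data.Dataset of (index, vals) pairs from one or more shards."""
    return tf.data.TFRecordDataset(file_paths).map(parse_record)
```

For example, `load_tfrecords(tf.io.gfile.glob("/path/to/data-*.tfrecord"))` would yield `(index, vals)` pairs across all shards written by `convert_records`; the path is a placeholder.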
