用于读取稀疏数据的 TensorFlow 输入函数(采用 libsvm 格式) [英] TensorFlow input function for reading sparse data (in libsvm format)

查看:30
本文介绍了用于读取稀疏数据的 TensorFlow 输入函数(采用 libsvm 格式)的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我是 TensorFlow 的新手,并尝试使用 Estimator API 进行一些简单的分类实验.我有一个 libsvm 格式的稀疏数据集.以下输入函数适用于小型数据集:

I'm new to TensorFlow and trying to use the Estimator API for some simple classification experiments. I have a sparse dataset in libsvm format. The following input function works for small datasets:

def libsvm_input_function(file):

    def input_function():

        indexes_raw = []
        indicators_raw = []
        values_raw = []
        labels_raw = []
        i=0

        for line in open(file, "r"):
            data = line.split(" ")
            label = int(data[0])
            for fea in data[1:]:
                id, value = fea.split(":")
                indexes_raw.append([i,int(id)])
                indicators_raw.append(int(1))
                values_raw.append(float(value))
            labels_raw.append(label)
            i=i+1

        indexes = tf.SparseTensor(indices=indexes_raw,
                              values=indicators_raw,
                              dense_shape=[i, num_features])

        values = tf.SparseTensor(indices=indexes_raw,
                             values=values_raw,
                             dense_shape=[i, num_features])

        labels = tf.constant(labels_raw, dtype=tf.int32)

        return {"indexes": indexes, "values": values}, labels

    return input_function

但是,对于几 GB 大小的数据集,我收到以下错误:

However, for a dataset of a few GB size I get the following error:

ValueError: 无法创建内容大于 2GB 的张量原型.

ValueError: Cannot create a tensor proto whose content is larger than 2GB.

我怎样才能避免这个错误?我应该如何编写一个输入函数来读取中等大小的稀疏数据集(libsvm 格式)?

How can I avoid this error? How should I write an input function to read medium-sized sparse datasets (in libsvm format)?

推荐答案

我一直在使用 tensorflow.contrib.libsvm.这是一个示例(我正在使用带有生成器的 Eager Execution)

I have been using tensorflow.contrib.libsvm. Here's an example (i am using eager execution with generators)

import os
import tensorflow as tf
import tensorflow.contrib.libsvm as libsvm


def all_libsvm_files(folder_path):
    for file in os.listdir(folder_path):
        if file.endswith(".libsvm"):
            yield os.path.join(folder_path, file)

def load_libsvm_dataset(path_to_folder):
    return tf.data.TextLineDataset(list(all_libsvm_files(path_to_folder)))


def libsvm_iterator(path_to_folder):
    dataset = load_libsvm_dataset(path_to_folder)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    yield libsvm.decode_libsvm(tf.reshape(next_element, (1,)),
                               num_features=666,
                               dtype=tf.float32,
                               label_dtype=tf.float32)

libsvm_iterator 在每次迭代时为您提供一个特征标签对,来自您指定的文件夹内的多个文件.

libsvm_iterator gives you a feature-label pair back on each iteration, from multiple files inside a folder that you specify.

这篇关于用于读取稀疏数据的 TensorFlow 输入函数(采用 libsvm 格式)的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆