用于读取稀疏数据的 TensorFlow 输入函数(采用 libsvm 格式) [英] TensorFlow input function for reading sparse data (in libsvm format)
问题描述
我是 TensorFlow 的新手,并尝试使用 Estimator API 进行一些简单的分类实验.我有一个 libsvm 格式的稀疏数据集.以下输入函数适用于小型数据集:
I'm new to TensorFlow and trying to use the Estimator API for some simple classification experiments. I have a sparse dataset in libsvm format. The following input function works for small datasets:
def libsvm_input_function(file):
def input_function():
indexes_raw = []
indicators_raw = []
values_raw = []
labels_raw = []
i=0
for line in open(file, "r"):
data = line.split(" ")
label = int(data[0])
for fea in data[1:]:
id, value = fea.split(":")
indexes_raw.append([i,int(id)])
indicators_raw.append(int(1))
values_raw.append(float(value))
labels_raw.append(label)
i=i+1
indexes = tf.SparseTensor(indices=indexes_raw,
values=indicators_raw,
dense_shape=[i, num_features])
values = tf.SparseTensor(indices=indexes_raw,
values=values_raw,
dense_shape=[i, num_features])
labels = tf.constant(labels_raw, dtype=tf.int32)
return {"indexes": indexes, "values": values}, labels
return input_function
但是,对于几 GB 大小的数据集,我收到以下错误:
However, for a dataset of a few GB size I get the following error:
ValueError: 无法创建内容大于 2GB 的张量原型.
ValueError: Cannot create a tensor proto whose content is larger than 2GB.
我怎样才能避免这个错误?我应该如何编写一个输入函数来读取中等大小的稀疏数据集(libsvm 格式)?
How can I avoid this error? How should I write an input function to read medium-sized sparse datasets (in libsvm format)?
推荐答案
我一直在使用 tensorflow.contrib.libsvm
.这是一个示例(我正在使用带有生成器的 Eager Execution)
I have been using tensorflow.contrib.libsvm
. Here's an example (i am using eager execution with generators)
import os
import tensorflow as tf
import tensorflow.contrib.libsvm as libsvm
def all_libsvm_files(folder_path):
for file in os.listdir(folder_path):
if file.endswith(".libsvm"):
yield os.path.join(folder_path, file)
def load_libsvm_dataset(path_to_folder):
return tf.data.TextLineDataset(list(all_libsvm_files(path_to_folder)))
def libsvm_iterator(path_to_folder):
dataset = load_libsvm_dataset(path_to_folder)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
yield libsvm.decode_libsvm(tf.reshape(next_element, (1,)),
num_features=666,
dtype=tf.float32,
label_dtype=tf.float32)
libsvm_iterator
在每次迭代时为您提供一个特征标签对,来自您指定的文件夹内的多个文件.
libsvm_iterator
gives you a feature-label pair back on each iteration, from multiple files inside a folder that you specify.
这篇关于用于读取稀疏数据的 TensorFlow 输入函数(采用 libsvm 格式)的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!