Tensorflow - 数据集 API 中的字符串处理 [英] Tensorflow - String processing in Dataset API
问题描述
我在
格式的目录中有 .txt
文件.我正在使用 TextLineDataset
API 来使用这些文本记录:
I have .txt
files in a directory of format <text>\t<label>
. I am using the TextLineDataset
API to consume these text records:
filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.contrib.data.Dataset.from_tensor_slices(filenames)
dataset = dataset.flat_map(
lambda filename: (
tf.contrib.data.TextLineDataset(filename)
.map(_parse_data)))
def _parse_data(line):
line_split = tf.string_split([line], '\t')
features = {"raw_text": tf.string(line_split.values[0].strip().lower()),
"label": tf.string_to_number(line_split.values[1],
out_type=tf.int32)}
parsed_features = tf.parse_single_example(line, features)
return parsed_features["raw_text"], raw_features["label"]
我想对 raw_text 功能进行一些字符串清理/处理.当我尝试运行 line_split.values[0].strip().lower()
时,出现以下错误:
I would like to do some string cleaning/processing on the raw_text feature. When I try to run line_split.values[0].strip().lower()
, I get the following error:
AttributeError: 'Tensor' 对象没有属性 'strip'
AttributeError: 'Tensor' object has no attribute 'strip'
推荐答案
对象 lines_split.values[0]
是一个 tf.Tensor
对象,表示从 line
开始的第 0 个分割.它不是 Python 字符串,因此它没有 .strip()
或 .lower()
方法.相反,您必须将 TensorFlow 操作应用于张量才能执行转换.
The object lines_split.values[0]
is a tf.Tensor
object representing the 0th split from line
. It is not a Python string, and so it does not have a .strip()
or .lower()
method. Instead you will have to apply TensorFlow operations to the tensor to perform the conversion.
TensorFlow 目前没有很多字符串操作,但您可以使用tf.py_func()
op 运行一些 Pythontf.Tensor
上的代码:
TensorFlow currently doesn't have very many string operations, but you can use the tf.py_func()
op to run some Python code on a tf.Tensor
:
def _parse_data(line):
line_split = tf.string_split([line], '\t')
raw_text = tf.py_func(
lambda x: x.strip().lower(), line_split.values[0], tf.string)
label = tf.string_to_number(line_split.values[1], out_type=tf.int32)
return {"raw_text": raw_text, "label": label}
请注意,问题中的代码还有一些其他问题:
Note that there are a couple of other problems with the code in the question:
- 不要使用
tf.parse_single_example()
一>.此操作仅用于解析tf.train.Example
协议缓冲区字符串;解析文本时不需要使用它,可以直接从_parse_data()
返回提取的特征. - 使用
dataset.map()
而不是dataset.flat_map()
.当映射函数的结果是Dataset
对象时,您只需要使用flat_map()
(因此返回值需要扁平化成单个数据集).当结果是一个或多个tf.Tensor
对象时,您必须使用map()
.
- Don't use
tf.parse_single_example()
. This op is only used for parsingtf.train.Example
protocol buffer strings; you do not need to use it when parsing text, and you can return the extracted features directly from_parse_data()
. - Use
dataset.map()
instead ofdataset.flat_map()
. You only need to useflat_map()
when the result of your mapping function is aDataset
object (and hence the return values need to be flattened into a single dataset). You must usemap()
when the result is one or moretf.Tensor
objects.
这篇关于Tensorflow - 数据集 API 中的字符串处理的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!