Tensorflow - 数据集 API 中的字符串处理 [英] Tensorflow - String processing in Dataset API

查看:46
本文介绍了Tensorflow - 数据集 API 中的字符串处理的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我在 \t 格式的目录中有 .txt 文件.我正在使用 TextLineDataset API 来使用这些文本记录:

I have .txt files in a directory of format <text>\t<label>. I am using the TextLineDataset API to consume these text records:

filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]

dataset = tf.contrib.data.Dataset.from_tensor_slices(filenames)

dataset = dataset.flat_map(
    lambda filename: (
        tf.contrib.data.TextLineDataset(filename)
        .map(_parse_data)))

def _parse_data(line):   
    line_split = tf.string_split([line], '\t')
    features = {"raw_text": tf.string(line_split.values[0].strip().lower()),
                "label": tf.string_to_number(line_split.values[1], 
                    out_type=tf.int32)}
    parsed_features = tf.parse_single_example(line, features)
    return parsed_features["raw_text"], raw_features["label"]

我想对 raw_text 功能进行一些字符串清理/处理.当我尝试运行 line_split.values[0].strip().lower() 时,出现以下错误:

I would like to do some string cleaning/processing on the raw_text feature. When I try to run line_split.values[0].strip().lower(), I get the following error:

AttributeError: 'Tensor' 对象没有属性 'strip'

AttributeError: 'Tensor' object has no attribute 'strip'

推荐答案

对象 lines_split.values[0] 是一个 tf.Tensor 对象,表示从 line 开始的第 0 个分割.它不是 Python 字符串,因此它没有 .strip().lower() 方法.相反,您必须将 TensorFlow 操作应用于张量才能执行转换.

The object lines_split.values[0] is a tf.Tensor object representing the 0th split from line. It is not a Python string, and so it does not have a .strip() or .lower() method. Instead you will have to apply TensorFlow operations to the tensor to perform the conversion.

TensorFlow 目前没有很多字符串操作,但您可以使用tf.py_func() op 运行一些 Pythontf.Tensor 上的代码:

TensorFlow currently doesn't have very many string operations, but you can use the tf.py_func() op to run some Python code on a tf.Tensor:

def _parse_data(line):
    line_split = tf.string_split([line], '\t')

    raw_text = tf.py_func(
        lambda x: x.strip().lower(), line_split.values[0], tf.string)

    label = tf.string_to_number(line_split.values[1], out_type=tf.int32)

    return {"raw_text": raw_text, "label": label}

请注意,问题中的代码还有一些其他问题:

Note that there are a couple of other problems with the code in the question:

  • Don't use tf.parse_single_example(). This op is only used for parsing tf.train.Example protocol buffer strings; you do not need to use it when parsing text, and you can return the extracted features directly from _parse_data().
  • Use dataset.map() instead of dataset.flat_map(). You only need to use flat_map() when the result of your mapping function is a Dataset object (and hence the return values need to be flattened into a single dataset). You must use map() when the result is one or more tf.Tensor objects.

这篇关于Tensorflow - 数据集 API 中的字符串处理的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆