TensorFlow - Extract string from Tensor

Question

I'm trying to follow the "Load using tf.data" part of this tutorial. In the tutorial, they can get away with working only with string Tensors; however, I need to extract the string representation of the filename, because I need to look up extra data in a dictionary. I can't seem to extract the string part of a Tensor. I'm pretty sure the .name attribute of a Tensor should return the string, but I keep getting the error KeyError: 'strided_slice_1:0', so somehow the slicing is doing something weird?

I'm loading the dataset using:

dataset_list = tf.data.Dataset.list_files(str(DATASET_DIR / "data/*"))

and then mapping over it with:

def process(t):
    return dataset.process_image_path(t, param_data, param_min_max)

dataset_labeled = dataset_list.map(
    process, 
    num_parallel_calls=AUTOTUNE)

其中 param_dataparam_min_max 是我加载的两个字典,其中包含构建标签所需的额外数据.

where param_data and param_min_max are two dictionaries I've loaded that contain extra data needed to construct the label.

These are the three functions that I use to process the data Tensors (from my dataset.py):

import os

import numpy as np
import tensorflow as tf


def process_image_path(image_path, param_data_file, param_max_min_file):
    label = path_to_label(image_path, param_data_file, param_max_min_file)
    img = tf.io.read_file(image_path)
    img = decode_img(img)
    return (img, label)


def decode_img(img):
    """Decodes a JPEG into a 3D float32 tensor with values in [0, 1]."""
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.convert_image_dtype(img, tf.float32)
    return img


def path_to_label(image_path, param_data_file, param_max_min_file):
    """Returns the NORMALIZED label (set of parameter values) of an image."""
    parts = tf.strings.split(image_path, os.path.sep)
    filename = parts[-1]  # Extract filename with extension
    filename = tf.strings.split(filename, ".")[0].name  # Extract filename
    param_data = param_data_file[filename]  # ERROR! .name above doesn't seem to return just the filename
    P = len(param_max_min_file)

    label = np.zeros(P)

    i = 0
    while i < P:
        param = param_max_min_file[i]
        umin = param["user_min"]
        umax = param["user_max"]
        sub_index = param["sub_index"]
        identifier = param["identifier"]
        node = param["node_name"]
        value = param_data[node][identifier]

        label[i] = _normalize(value[sub_index])
        i += 1

    return label

I have verified that filename = tf.strings.split(filename, ".")[0] in path_to_label() does return the correct Tensor, but I need it as a string. The whole thing is proving difficult to debug as well, as I can't access attributes when debugging (I get errors saying AttributeError: Tensor.name is meaningless when eager execution is enabled.).
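Because Dataset.map traces the mapped function into a graph, Python print() and attribute access only see symbolic tensors, and only once, at trace time. One way to inspect the actual values while debugging is tf.print, which runs every time the graph executes. A minimal sketch, assuming TensorFlow 2.x; the paths and the extract_stem helper are hypothetical:

```python
import tensorflow as tf  # assumes TensorFlow 2.x

# Hypothetical stand-in for tf.data.Dataset.list_files(...): a dataset of path strings.
paths = tf.data.Dataset.from_tensor_slices(["data/img_001.jpg", "data/img_002.jpg"])


def extract_stem(image_path):
    """Hypothetical helper: file name without directory and extension."""
    filename = tf.strings.split(image_path, "/")[-1]  # e.g. "img_001.jpg"
    stem = tf.strings.split(filename, ".")[0]         # e.g. "img_001"
    # tf.print executes with the graph, so it prints the tensor's content for
    # every element; a plain print() here would run only once, during tracing.
    tf.print("stem:", stem)
    return stem


stems = paths.map(extract_stem)
result = [s.numpy() for s in stems]  # string tensors come back as bytes
```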

Answer

The name attribute is the name of the tensor itself (an identifier in the graph, such as the 'strided_slice_1:0' in your KeyError), not the content of the tensor.
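To illustrate the difference, here is a small sketch assuming TensorFlow 2.x (eager execution by default) and a hypothetical path. Eagerly, a string tensor's content is read with .numpy() and comes back as bytes; inside a tf.function, which is how Dataset.map runs your code, the same expression is a symbolic tensor whose .name is only a graph identifier.

```python
import tensorflow as tf  # assumes TensorFlow 2.x

path = tf.constant("data/img_001.jpg")  # hypothetical file path
stem = tf.strings.split(tf.strings.split(path, "/")[-1], ".")[0]

# Eager mode: the content is available directly, as bytes.
stem_bytes = stem.numpy()
stem_str = stem_bytes.decode("utf-8")  # a plain str, usable as a dict key


@tf.function
def graph_name(p):
    s = tf.strings.split(tf.strings.split(p, "/")[-1], ".")[0]
    # In a graph, s has no concrete value at trace time; s.name is just the
    # identifier of the op output (the kind of string that ended up in the KeyError).
    return tf.constant(s.name)


name_in_graph = graph_name(path).numpy()  # the tensor's name, not its content
```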

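To do a regular Python dictionary lookup, wrap your parsing function in tf.py_func so that it runs as ordinary Python on the tensor's value. (This answer targets TensorFlow 1.x; in TensorFlow 2.x, tf.py_func is only available as tf.compat.v1.py_func, and tf.numpy_function or tf.py_function are its replacements.)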

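import numpy as np
import tensorflow as tf
tf.enable_eager_execution()  # TensorFlow 1.x only

d = {"a": 1, "b": 3, "c": 10}
dataset = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])

def parse(s):
  # Inside tf.py_func, s arrives as a bytes object (e.g. b"a"), so decode it
  # before looking it up in a dict with str keys. np.int64 matches the
  # declared tf.int64 output type regardless of platform.
  return s, np.int64(d[s.decode("utf-8")])

dataset = dataset.map(lambda s: tf.py_func(parse, [s], (tf.string, tf.int64)))

for element in dataset:
  print(element[1].numpy())  # prints 1, 3, 10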
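A sketch of the same idea for TensorFlow 2.x, applied to the filename lookup from the question, using tf.numpy_function (tf.py_function would also work, but passes an EagerTensor, so you would call .numpy() before .decode()). PARAM_DATA, its two-element labels, and the helper names are hypothetical stand-ins for the question's param_data and label construction.

```python
import os

import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x

# Hypothetical lookup table keyed by file stem (stands in for param_data).
PARAM_DATA = {"img_001": [0.1, 0.2], "img_002": [0.3, 0.4]}


def _lookup_label(image_path):
    # numpy_function hands a scalar string tensor to Python as bytes.
    path = image_path.decode("utf-8")
    stem = os.path.splitext(os.path.basename(path))[0]  # "img_001"
    return np.asarray(PARAM_DATA[stem], dtype=np.float32)


def path_to_label(image_path):
    label = tf.numpy_function(_lookup_label, [image_path], tf.float32)
    # numpy_function cannot infer the output shape; restore it for later layers.
    label.set_shape([2])
    return label


paths = tf.data.Dataset.from_tensor_slices(["data/img_001.jpg", "data/img_002.jpg"])
labels = [label.numpy().tolist() for label in paths.map(path_to_label)]
```

Note that the Python callback runs under the Python GIL and is not serialized with the graph, so a pipeline built this way will not parallelize that step as well as pure TensorFlow ops, and a model exported with it cannot run the lookup without the Python function.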
