如何检查张量的值是否包含在其他张量中? [英] How to check that the values of Tensor is contained in other tensor?

查看:85
本文介绍了如何检查张量的值是否包含在其他张量中?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我在从其他张量中寻找值时遇到问题

I have a problem about finding a value from other tensor

它类似于以下问题:(网址:如何从 Tensorflow 中的其他张量中找到张量中的值)

It's similar to the following problem : (URL: How to find a value in tensor from other tensor in Tensorflow)

之前的问题是询问输入张量x[i]y[i]是否包含在输入张量label_x、<强>label_y

The previous problem was to ask if input tensor x[i], y[i] is contained in input tensor label_x, label_y

以下是上一个问题的示例:

Here is an example of the previous problem:

Input Tensor
s_idx = (1, 3, 5, 7)
e_idx = (3, 4, 5, 8)

label_s_idx = (2, 2, 3, 6)
label_e_idx = (2, 3, 4, 8)

问题是给output[i]一个值为1如果 s_idx[i] == label_s_idx[j] 和 e_idx[i] == label_s_idx[j] for some j 满足某些 j.

The problem is to give output[i] a value of 1 if s_idx[i] == label_s_idx[j] and e_idx[i] == label_s_idx[j] for some j are satisfied for some j.

因此,在上面的例子中,输出张量是

Thus, in the above example, the output tensor is

output = (0, 1, 0, 0)

因为 (s_idx[1] = 3, e_idx[1] = 4) 与 (label_s_idx[2] = 3, label_e_idx[2] = 4)

Because (s_idx[1] = 3, e_idx[1] = 4) is same as (label_s_idx[2] = 3, label_e_idx[2] = 4)

(s_idx, e_idx) 没有重复值,而 (label_s_idx, label_e_idx) 有.

(s_idx, e_idx) does not have a duplicate value, and (label_s_idx, label_e_idx) does so.

因此,假设以下输入示例是不可能的:

Therefore, it is assumed that the following input example is impossible:

s_idx = (2, 2, 3, 3)
e_idx = (2, 3, 3, 3)

因为,(s_idx[2] = 3, e_idx[2] = 3) 与 (s_idx[3] =3、e_idx[3] = 3).

Because, (s_idx[2] = 3, e_idx[2] = 3) is same as (s_idx[3] = 3, e_idx[3] = 3).

我想在这个问题中稍微改变的是给输入张量添加另一个值:

What I want to change a bit in this problem is to add another value to the input tensor:

Input Tensor
s_idx = (1, 3, 5, 7)
e_idx = (3, 4, 5, 8)

label_s_idx = (2, 2, 3, 6)
label_e_idx = (2, 3, 4, 8)
label_score = (1, 3, 2, 3)

*label_score 张量中没有 0 值

*There is no 0 values in label_score tensor

变更问题中的任务定义如下:

The task in the changed problem is defined as follows:

问题是如果 s_idx[i] == label_s_idx[j]e_idx[i] == label_s_idx[j] 给 output_2[i] 一个值 label_score[j]]对于一些 j 是满意的.

The problem is to give output_2[i] a value of label_score[j] if s_idx[i] == label_s_idx[j] and e_idx[i] == label_s_idx[j] for some j are satisfied.

因此,output_2 应该是这样的:

Therefore, the output_2 should be like this:

output = (0, 1, 0, 0)  // It is same as previous problem
output_2 = (0, 2, 0, 0)

如何在 Python 中的 Tensorflow 上编写这样的代码?

How do I code like this on Tensorflow in Python?

推荐答案

这是一个可能的解决方案:

Here is a possible solution:

import tensorflow as tf

s_idx = tf.placeholder(tf.int32, [None])
e_idx = tf.placeholder(tf.int32, [None])
label_s_idx = tf.placeholder(tf.int32, [None])
label_e_idx = tf.placeholder(tf.int32, [None])
label_score = tf.placeholder(tf.int32, [None])

# Stack inputs for comparison
se_idx = tf.stack([s_idx, e_idx], axis=1)
label_se_idx = tf.stack([label_s_idx, label_e_idx], axis=1)
# Compare every pair to each other and find matches
cmp = tf.equal(se_idx[:, tf.newaxis, :], label_se_idx[tf.newaxis, :, :])
matches = tf.reduce_all(cmp, axis=2)
# Find the position of the matches
match_pos = tf.argmax(tf.cast(matches, tf.int8), axis=1)
# For those positions where a match was found take the corresponding score
output = tf.where(tf.reduce_any(matches, axis=1),
                  tf.gather(label_score, match_pos),
                  tf.zeros_like(label_score))

# Test
with tf.Session() as sess:
    print(sess.run(output, feed_dict={s_idx: [1, 3, 5, 7],
                                      e_idx: [3, 4, 5, 8],
                                      label_s_idx: [2, 2, 3, 6],
                                      label_e_idx: [2, 3, 4, 8],
                                      label_score: [1, 3, 2, 3]}))
# >>> [0 2 0 0]

它将每对值相互比较,因此成本是输入大小的二次方.此外,tf.argmax 用于查找匹配位置的索引,如果有多个可能的索引,它可能会不确定地返回其中任何一个.

It compares every pair of values to each other, so the cost is quadratic on the input size. Also, tf.argmax is used to find the index of the matching position, and if there is more than one possible index it may return any of them nondeterministically.

这篇关于如何检查张量的值是否包含在其他张量中?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆