在 tensorflow.js 中对 Tensor 进行分区、屏蔽或过滤 [英] partition or mask or filter a Tensor in tensorflow.js

查看:38
本文介绍了在 tensorflow.js 中对 Tensor 进行分区、屏蔽或过滤的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有 2 个长度相同的张量,datagroupIds.我想通过 groupId 中的相应值将 data 分成几个组.例如,

I have 2 Tensors of same length, data and groupIds. I want to split data into several groups by the corresponding values in groupId. For example,

const data = tf.tensor([1,2,3,4,5]);
const groupIds = tf.tensor([0,1,1,0,0]);
// expected result: [tf.tensor([1,4,5]), tf.tensor([2,3])]

在 Tensorflow 中有 tf.dynamic_partition 正是这样做的.Tensorflow.js 似乎没有类似的方法.我还研究了掩码或过滤作为解决方法,但它们也不存在.有没有人知道如何实现这一点?

In Tensorflow there is tf.dynamic_partition which does exactly that. Tensorflow.js doesn't seem to have a similar method. I also looked into mask or filtering as work-arounds, but they don't exist either. Does anyone have an idea how to implement this?

推荐答案

要对您的张量进行分区,您可以先遍历您的 ids 张量以获取要创建的子张量的数量和该张量的索引它应该包含的元素.此信息可以存储在一个对象中,其中键是 ids 数组中的分区编号,值是索引数组.

To partition your tensor, you can first iterate over your ids tensor to get the number of subtensor to create and the index of the elements it should contain. This information can be stored in an object where the key is the number of the partition in the ids array and the value is an array of indexes.

const data = tf.tensor([6,2,8,4,5]);
const ids = tf.tensor([0,1,1,0,2]);

const data2 = tf.tensor([[6,2],[8,4], [5, 4], [6, 5]]);
const ids2 = tf.tensor([0,1,1,0]);

const filterT = (t, p) => {
  t.print()
  p.print()
  const l = p.unstack().reduce((a, b, i) => {
  const v = b.dataSync()[0]
  if (Object.keys(a).includes(v.toString())) {
    a[v].push(i)
  } else {
    a[v] = [i]
  }
  return a
}, {})

  const r = Object.keys(l).map(k => t.gather(tf.tensor1d(l[k], 'int32')))
  r.forEach(e => e.print())
}

filterT(data, ids)
filterT(data2, ids2)

<html>
  <head>
    <!-- Load TensorFlow.js -->
    <script src="https://cdnjs.cloudflare.com/ajax/libs/tensorflow/0.12.4/tf.js"> </script>
  </head>

  <body>
  </body>
</html>

这篇关于在 tensorflow.js 中对 Tensor 进行分区、屏蔽或过滤的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆