How to get indices of a specific value in a tensor in tensorflow-js?

Problem description

For example, if I have the 2D tensor [[1, 3], [2, 1]], how can I get the indices of the value 1? (It should return [[0, 0], [1, 1]].)

I looked into tf.where, but the API is complicated and I don't think it would solve the problem for me.

Tags: tensorflow.js

Solution


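You can achieve this with tf.whereAsync rather than tf.where. In TensorFlow.js, tf.where(condition, a, b) selects elements from a or b depending on the condition, so it does not return indices.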

First, create a boolean mask that marks which elements of the input tensor are equal to 1:

Mask:

Tensor
    [[true , false],
     [false, true ]]
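To make the mask step concrete without TensorFlow.js, here is a plain-JavaScript sketch of the same element-wise comparison on a nested array. equalMask is a hypothetical helper for illustration only; it is not part of the tfjs API.

```javascript
// Plain-JavaScript sketch of the mask step; no TensorFlow.js involved.
// equalMask is a hypothetical helper, not a tfjs function.
function equalMask(matrix, value) {
  // Compare every element with `value`, keeping the 2D shape.
  return matrix.map(row => row.map(v => v === value));
}

const mask = equalMask([[1, 3], [2, 1]], 1);
console.log(mask); // [ [ true, false ], [ false, true ] ]
```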

tf.whereAsync(condition) returns an int32 tensor with one row per true element of condition, each row holding that element's coordinates. Here the condition is the mask.
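The coordinate extraction can likewise be sketched in plain JavaScript. The hypothetical helper trueCoords walks a nested boolean array of any rank and collects the index tuple of each true element in row-major order, which matches the order in the output shown for tf.whereAsync. It only illustrates the semantics; it is not how tfjs implements the operation.

```javascript
// Plain-JavaScript sketch of what tf.whereAsync computes for a boolean
// mask: the index tuple of every `true` element, in row-major order.
// trueCoords is a hypothetical helper, not a tfjs function.
function trueCoords(mask, prefix = [], out = []) {
  mask.forEach((item, i) => {
    const idx = [...prefix, i];
    if (Array.isArray(item)) {
      trueCoords(item, idx, out); // descend into the next dimension
    } else if (item === true) {
      out.push(idx); // record the full coordinate of a true element
    }
  });
  return out;
}

console.log(trueCoords([[true, false], [false, true]])); // [ [ 0, 0 ], [ 1, 1 ] ]
```

For a 1D mask such as [false, true, true] the sketch returns [[1], [2]], i.e. still one row per match.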

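// Assumes tf (TensorFlow.js) is already loaded, e.g. via require('@tensorflow/tfjs').
(async function getData() {
  const x = tf.tensor2d([[1, 3], [2, 1]]);

  // equal() already returns a bool tensor, so no extra cast is needed.
  const mask = x.equal(1);
  // Coordinates of every true element, one row per match.
  const coords = await tf.whereAsync(mask);
  coords.print();

  // Free the tensors' memory once they are no longer needed.
  tf.dispose([x, mask, coords]);
})();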

Input:

Tensor
    [[1, 3],
     [2, 1]]

Output:

Tensor
    [[0, 0],
     [1, 1]]
