首页 > 解决方案 > 如何在张量流中沿选定轴获得等级高于 2 的张量的对角线

问题描述

我有一个形状张量,tf.shape(input)=(Batch_Size,Channels,N,N)我的目标是计算和输出包含沿轴 2 和 3 的所有对角线元素。以便tf.shape(output)=(Batch_Size,Channels,N)

有这个功能tf.diag_part(input),但它不允许我选择我想要考虑的轴。如何定义一个为我执行此操作的函数?

以下代码可以工作吗?

Batches=[]
for batch in input:
    diagonalpart=tf.diag_part(batch)
    Batches.append(diagonalpart)
output=tf.stack(Batches)

标签: pythontensorflowtensor

解决方案


tf.linalg.diag_part应该完全符合您的要求,例如

import tensorflow as tf
import numpy as np

# Input shape: (2, 2, 4, 4)
input = np.array([
                [ [[1, 2, 3, 4],
                   [5, 6, 7, 8],
                   [9, 8, 7, 6],
                   [9, 8, 7, 6]],
                  [[5, 4, 3, 2],
                   [1, 2, 3, 4],
                   [5, 6, 7, 8],
                   [1, 2, 3, 4]] ], 
                [ [[1, 2, 3, 4],
                   [5, 6, 7, 8],
                   [9, 8, 7, 6],
                   [1, 2, 3, 4]],
                  [[5, 4, 3, 2],
                   [1, 2, 3, 4],
                   [5, 6, 7, 8],
                   [9, 8, 7, 6]] ] 
                                ])

print(tf.linalg.diag_part(input))

将输出:

tf.Tensor(
[[[1 6 7 6]
  [5 2 7 4]]

 [[1 6 7 4]
  [5 2 7 6]]], shape=(2, 2, 4), dtype=int32)

推荐阅读