python - 如何在张量流中沿选定轴获得等级高于 2 的张量的对角线
问题描述
我有一个形状张量,tf.shape(input)=(Batch_Size,Channels,N,N)
我的目标是计算和输出包含沿轴 2 和 3 的所有对角线元素。以便tf.shape(output)=(Batch_Size,Channels,N)
有这个功能tf.diag_part(input)
,但它不允许我选择我想要考虑的轴。如何定义一个为我执行此操作的函数?
以下代码可以工作吗?
Batches=[]
for batch in input:
diagonalpart=tf.diag_part(batch)
Batches.append(diagonalpart)
output=tf.stack(Batches)
解决方案
tf.linalg.diag_part应该完全符合您的要求,例如:
import tensorflow as tf
import numpy as np
# Input shape: (2, 2, 4, 4)
input = np.array([
[ [[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 8, 7, 6],
[9, 8, 7, 6]],
[[5, 4, 3, 2],
[1, 2, 3, 4],
[5, 6, 7, 8],
[1, 2, 3, 4]] ],
[ [[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 8, 7, 6],
[1, 2, 3, 4]],
[[5, 4, 3, 2],
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 8, 7, 6]] ]
])
print(tf.linalg.diag_part(input))
将输出:
tf.Tensor(
[[[1 6 7 6]
[5 2 7 4]]
[[1 6 7 4]
[5 2 7 6]]], shape=(2, 2, 4), dtype=int32)
推荐阅读
- google-analytics - 如何更改谷歌分析中的网站网址路径?
- java - Java 执行时间不加起来
- sql-server - 从 Excel 粘贴时 SQL Server 显示 Null,即使单元格中有值
- mysql - 无法从终端登录 mysql 上的任何帐户
- visual-studio - 是否可以在 Developer Powershell Terminal 中更改字体?
- docker-compose - 如何在启动时在 kong docker 容器中运行 curl 命令?
- cordova - 内容安全策略框架祖先错误地阻止来自文件系统的 safari
- javascript - 为什么我的 Javascript 不为该文本设置动画以更改颜色?
- java - 如何使用 MPAndroidChart 检索 JSON 数据并绘制 LineChart 并在 android 上进行改造
- gpu - 模拟低功耗 GPU 以帮助生成最低要求