首页 > 解决方案 > Tensorflow:创建一个对角矩阵,输入在子/超对角线上

问题描述

我有以下代码:

import tensorflow as tf

N = 10
X = tf.ones([N,], dtype=tf.float64)
D = tf.linalg.diag(X, k=1, num_rows=N+1, num_cols=N+1)

print(D)

其中,根据TF2 文档,我希望返回一个 11x11 张量,并在第一个超对角线上插入 X(即使没有可选参数num_rowsnum_cols参数)。然而,结果是

tf.Tensor(
[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]], shape=(10, 10), dtype=float64)

我有什么明显的遗漏吗?

标签: pythontensorflow

解决方案


我可以告诉你为什么这不起作用,但我不知道修复是什么。可能会提出一个github问题。

如果你在array_ops.py. 它会进行兼容性检查tf.compat.forward_compatible以查看兼容性窗口是否已过期。返回False(对于 TF 2.0.0 和 2.1.0rc0) 。由于这个原因,它执行

return gen_array_ops.matrix_diag(diagonal=diagonal, name=name)

您可以看到调用时没有使用k, num_rows, 。因此,如果检查失败num_cols,该方法目前完全忽略这些参数。tf.compat.forward_compatible


推荐阅读