python - Tensorflow:创建一个对角矩阵,输入在子/超对角线上
问题描述
我有以下代码:
import tensorflow as tf
N = 10
X = tf.ones([N,], dtype=tf.float64)
D = tf.linalg.diag(X, k=1, num_rows=N+1, num_cols=N+1)
print(D)
其中,根据TF2 文档,我希望返回一个 11x11 张量,并在第一个超对角线上插入 X(即使没有可选参数num_rows
和num_cols
参数)。然而,结果是
tf.Tensor(
[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]], shape=(10, 10), dtype=float64)
我有什么明显的遗漏吗?
解决方案
我可以告诉你为什么这不起作用,但我不知道修复是什么。可能会提出一个github问题。
如果你在array_ops.py
. 它会进行兼容性检查tf.compat.forward_compatible
以查看兼容性窗口是否已过期。返回False
(对于 TF 2.0.0 和 2.1.0rc0) 。由于这个原因,它执行
return gen_array_ops.matrix_diag(diagonal=diagonal, name=name)
您可以看到调用时没有使用k
, num_rows
, 。因此,如果检查失败num_cols
,该方法目前完全忽略这些参数。tf.compat.forward_compatible
推荐阅读
- php - 防止 DOMDocument 转义 html
- python - Mongodb DeprecationWarning:不推荐使用计数。请改用 Collection.count_documents
- node.js - Socket.io:无法读取未定义的属性“发射”
- azerothcore - 尝试为 azerothcore 安装 soloLFG 模块时出错
- laravel - 如何在 laravel 中仅获取名称值请建议
- python - Pandas 按重复日期分组
- javascript - 如何在静态主机上部署静态 JS 应用并上传图片?
- c - 在 Debian 上编译 GTK 3 应用程序
- java - SQL 错误:1364,SQLState:HY000 - 字段“rating_id”没有默认值/保存具有一对一关系的子实体时
- python - 使用替代模块满足点子要求