python - 如何在张量流中执行具有不同秩和外部维度的张量的三对角矩阵的乘法
问题描述
下面的代码(修改后的 tensorflow 示例)会产生错误“所有输入张量必须具有相同的等级。”。tf.linalg.LinearOperatorTridiag 的多重操作也会给出类似的错误。我需要将输入乘以 Keras 层中的三对角矩阵,并且由于层输入中的附加批次维度,张量的等级不同。任何已知的实用解决方案?
import tensorflow as tf
superdiag = tf.constant([-1, -1, 0], dtype=tf.float64)
maindiag = tf.constant([2, 2, 2], dtype=tf.float64)
subdiag = tf.constant([0, -1, -1], dtype=tf.float64)
diagonals = [superdiag, maindiag, subdiag]
rhs = tf.constant([[[1, 1], [1, 1], [1, 1]]], dtype=tf.float64)
x = tf.linalg.tridiagonal_matmul(diagonals, rhs, diagonals_format='sequence')
解决方案
你必须扩展第一个维度
superdiag = tf.constant([-1, -1, 0], dtype=tf.float64)
maindiag = tf.constant([2, 2, 2], dtype=tf.float64)
subdiag = tf.constant([0, -1, -1], dtype=tf.float64)
diagonals = [tf.expand_dims(superdiag,0), tf.expand_dims(maindiag,0), tf.expand_dims(subdiag,0)]
rhs = tf.constant([[[1, 1], [1, 1], [1, 1]]], dtype=tf.float64)
x = tf.linalg.tridiagonal_matmul(diagonals, rhs, diagonals_format='sequence')
推荐阅读
- mysql - 为什么不能将数据加载到mysql中:
- javascript - 来自 Fetch() 的响应未定义
- postgresql - Autovacuum 不删除死行(并且 xmin 水平与任何会话的 xmin 都不匹配)
- c++ - 如何获取视差图坐标对应的像素颜色
- java - 拆分此文本文件的最佳方法是什么?
- c++ - Eclipse 不构建 C++ 文件
- laravel - 我收到此错误“传递给 Illuminate\Database\Grammar::parameterize() 的参数 1 必须属于该类型
- java - 我将 Google 授权重定向 URI 设置为我的实时/测试服务器上的 URI,但 Google 报告了 localhost 的重定向 URI
- angular - 无法绑定到“consoleMessages”,因为它不是“app-console”的已知属性
- python - 如何在我想要退出的循环中做一个循环