首页 > 解决方案 > 如何在张量流中执行具有不同秩和外部维度的张量的三对角矩阵的乘法

问题描述

下面的代码(修改后的 tensorflow 示例)会产生错误“所有输入张量必须具有相同的等级。”。tf.linalg.LinearOperatorTridiag 的多重操作也会给出类似的错误。我需要将输入乘以 Keras 层中的三对角矩阵,并且由于层输入中的附加批次维度,张量的等级不同。任何已知的实用解决方案?

import tensorflow as tf

superdiag = tf.constant([-1, -1, 0], dtype=tf.float64)
maindiag = tf.constant([2, 2, 2], dtype=tf.float64)
subdiag = tf.constant([0, -1, -1], dtype=tf.float64)
diagonals = [superdiag, maindiag, subdiag]
rhs = tf.constant([[[1, 1], [1, 1], [1, 1]]], dtype=tf.float64)
x = tf.linalg.tridiagonal_matmul(diagonals, rhs, diagonals_format='sequence')

标签: pythontensorflowkeras

解决方案


你必须扩展第一个维度

superdiag = tf.constant([-1, -1, 0], dtype=tf.float64)
maindiag = tf.constant([2, 2, 2], dtype=tf.float64)
subdiag = tf.constant([0, -1, -1], dtype=tf.float64)
diagonals = [tf.expand_dims(superdiag,0), tf.expand_dims(maindiag,0), tf.expand_dims(subdiag,0)]
rhs = tf.constant([[[1, 1], [1, 1], [1, 1]]], dtype=tf.float64)
x = tf.linalg.tridiagonal_matmul(diagonals, rhs, diagonals_format='sequence')

推荐阅读